Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
[pallas] pallas_call_p is now parameterized by a mesh

The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
Parent: b4966130a3
Commit: 051687dc4c
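For orientation before the diff: the change reads exactly two things off a mesh everywhere, a backend string and a named shape, and the Mosaic GPU lowering additionally reads the mesh's cluster. The sketch below is illustrative only (the helper name is invented, not code from this commit).

# Illustrative sketch only (invented helper name): mirrors what the change
# derives from a mesh in the hunks that follow.
def params_from_mesh(mesh):
  backend = mesh.backend                  # e.g. "mosaic_tpu" or "mosaic_gpu"
  grid = tuple(mesh.shape.items())        # e.g. (("x", 2), ("wg", 2))
  cluster = getattr(mesh, "cluster", ())  # only GPUMesh-like meshes carry one
  return backend, grid, cluster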
@@ -15,6 +15,7 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations

+import collections
 from collections.abc import Callable, Iterable, Iterator, Sequence
 import contextlib
 import copy
@@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
   return [], effs


+class Mesh(Protocol):
+
+  @property
+  def backend(self) -> str:
+    ...
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    ...
+
+
 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}

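Because Mesh is a typing.Protocol, implementations are structural: any object exposing these two properties can be used where a mesh is expected, with no subclassing. A hypothetical minimal implementation (not part of the diff) could look like this.

import collections

class MyMesh:  # hypothetical; satisfies pallas_core.Mesh structurally
  @property
  def backend(self) -> str:
    return "mosaic_gpu"  # selects which pallas_call lowering to dispatch to

  @property
  def shape(self) -> collections.OrderedDict[object, int]:
    # axis name -> size; the order defines the grid dimension order
    return collections.OrderedDict([("x", 2), ("wg", 2)])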
@@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule(
     in_avals,
     out_avals,
     *args,
-    grid,
+    mesh,
     compiler_params,
-    backend,
     jaxpr,
     debug,
     interpret,
@@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule(
       if isinstance(eff, state_types.WriteEffect)
   )
   any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  grid_spec = GridSpec(
+      grid=tuple(mesh.shape.items()),
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+  )
   from jax._src.pallas import pallas_call  # Avoid circular dependency.
-  outs = pallas_call.pallas_call(
+  outs = pallas_call._pallas_call(
       body,
       name=name,
       out_shape=[in_avals[idx] for idx in modified_idxs],
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(modified_idxs),
       input_output_aliases={
           in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
       },
-      grid=grid,
+      grid_spec=grid_spec,
+      mesh=mesh,
       compiler_params=compiler_params,
-      backend=backend,
       interpret=interpret,
       debug=debug,
       cost_estimate=cost_estimate,
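In other words, the discharge rule no longer receives a separate grid: it rebuilds a named grid from the mesh shape. For a hypothetical mesh whose shape is OrderedDict([("x", 2), ("wg", 2)]), the grid passed to GridSpec above comes out as follows (illustrative values only).

import collections

mesh_shape = collections.OrderedDict([("x", 2), ("wg", 2)])  # assumed example
grid = tuple(mesh_shape.items())
assert grid == (("x", 2), ("wg", 2))  # one named grid dimension per mesh axis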
@@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
-  del compiler_params, cost_estimate, out_avals
+  del mesh, compiler_params, cost_estimate, out_avals
   debug_info = jaxpr.debug_info
   # If we're in interpret mode, we *scan* over the grid and eval the
   # discharged jaxpr.
@@ -211,6 +211,10 @@ class TensorCoreMesh:
   devices: np.ndarray
   axis_names: Sequence[str]

+  @property
+  def backend(self) -> str:
+    return "mosaic_tpu"
+
   @property
   def shape(self):
     return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
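For a TPU chip with two TensorCores, the new properties would report the following (illustrative reconstruction, not code from the diff).

import collections
import numpy as np

axis_names = ("core",)
devices = np.empty((2,), dtype=object)  # stand-in for two TensorCore devices
shape = collections.OrderedDict(zip(axis_names, devices.shape))
assert shape == collections.OrderedDict([("core", 2)])
# ...and TensorCoreMesh.backend is now the constant string "mosaic_tpu".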
@@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
     compiler_params = TPUCompilerParams()
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
-  core_axis_name, num_cores = list(mesh.shape.items())[0]
   if compiler_params.dimension_semantics is not None:
     raise ValueError(
         "dimension_semantics must be None for TensorCoreMesh"

@@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule(
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=((core_axis_name, num_cores),),
+      mesh=mesh,
       compiler_params=compiler_params.replace(
           dimension_semantics=(PARALLEL,)
       ),
       debug=debug,
       interpret=interpret,
-      backend="mosaic_tpu",
       cost_estimate=cost_estimate,
       name=name,
   )
@@ -1351,12 +1351,13 @@ def interpret_pallas_call(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
     interpret_params: TPUInterpretParams,
 ):
-  del debug, cost_estimate, out_avals
+  del debug, mesh, cost_estimate, out_avals

   # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
   dynamic_grid_args, scalars, input_args = split_list(
@@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,

@@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret
+  del mesh, interpret  # Unused.
+
   debug_info = jaxpr._debug_info
   if debug:
     print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")

@@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
   else:
     mosaic_params = {}

-  mesh = None
+  jax_mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
     if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      mesh = axis_context.mesh
+      jax_mesh = axis_context.mesh
   mlir_ctx = mlir.JaxIrContext()
   mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
   mlir_ctx.load_all_available_dialects()

@@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
         grid_mapping,
         jaxpr,
         dimension_semantics=dimension_semantics,
-        mesh=mesh,
+        mesh=jax_mesh,
         for_verification=for_verification,
         dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
     )

@@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule(
   )

   if promela_dump_path := _DUMP_PROMELA_TO.value:
-    num_devices = 1 if mesh is None else mesh.devices.size
+    num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
     num_cores = (
         jax.devices()[0].num_cores
-        if mesh is None
-        else mesh.devices[0].num_cores
+        if jax_mesh is None
+        else jax_mesh.devices[0].num_cores
     )
     verification_module, _ = lower_module(for_verification=True)
     model = verification.export_promela_model(
@@ -18,7 +18,7 @@ from __future__ import annotations

 import abc
 import collections
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 import dataclasses
 import enum
 import itertools as it
@@ -519,9 +519,16 @@ class GPUMesh:
     )

   @property
-  def shape(self):
+  def backend(self) -> str:
+    return "mosaic_gpu"
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    pairs: Iterable[tuple[object, int]]
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
+      pairs = zip(
+          self.axis_names, (*self.grid, *self.cluster, self.num_threads)
+      )
     else:
       pairs = tuple(
           zip(
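With the cluster axes included, the shape of a hypothetical GPUMesh with grid=(4,), cluster=(2,) and two warpgroups would come out as follows (illustrative values and axis names).

import collections

grid, cluster, num_threads = (4,), (2,), 2
axis_names = ("x", "cluster_x", "wg")
shape = collections.OrderedDict(zip(axis_names, (*grid, *cluster, num_threads)))
assert shape == collections.OrderedDict([("x", 4), ("cluster_x", 2), ("wg", 2)])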
@@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=tuple(mesh.shape.items()),
-      backend="mosaic_gpu",
+      mesh=mesh,
       compiler_params=compiler_params,
       debug=debug,
       interpret=interpret,
@@ -450,6 +450,7 @@ def _block_spec_from_block_mapping(

 def lower_pipelined_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     jaxpr: jax_core.Jaxpr,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
@@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module(
       block_mappings, [grid_mapping.num_inputs]
   )

-  if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
+  if mesh is not None:
+    assert isinstance(mesh, gpu_core.GPUMesh)
+  if mesh and mesh.num_threads is not None:
+    # Last dim corresponds to the warpgroup count.
     block = (128 * grid_mapping.grid[-1], 1, 1)
     grid = grid_mapping.grid[:-1]
   else:
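Since a Mosaic GPU warpgroup is 128 threads, the trailing grid dimension (the warpgroup count coming from the mesh) is folded into the CUDA block size while the remaining dimensions stay in the launch grid. A quick worked example with assumed numbers:

grid = (4, 2)                     # trailing dim: 2 warpgroups per block (assumed)
block = (128 * grid[-1], 1, 1)
parallel_grid = grid[:-1]
assert block == (256, 1, 1) and parallel_grid == (4,)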
@@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module(
       parallel_grid,
       grid_mapping.grid_names,
       block,
+      mesh.cluster if mesh is not None else (),
       [bm.array_shape_dtype for bm in in_block_mappings],
       [bm.array_shape_dtype for bm in out_block_mappings],
       new_jaxpr,
@@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
     grid: Sequence[int],
     grid_names: Sequence[str],
     block: Sequence[int],
+    cluster: Sequence[int],
     in_shapes: Sequence[jax.ShapeDtypeStruct],
     out_shapes: Sequence[jax.ShapeDtypeStruct],
     jaxpr: jax_core.Jaxpr,
@@ -640,7 +646,7 @@ def lower_jaxpr_to_module(
       mgpu_core._lower_as_gpu_kernel(
           body,
           grid=parallel_grid,
-          cluster=(),
+          cluster=cluster,
           block=block,
           in_shapes=in_shapes,
           out_shape=out_shapes,
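Previously the kernel was always launched with an empty cluster; now the cluster comes from the mesh when one is given. A sketch of the value that ends up in the cluster= argument, using a hypothetical stand-in mesh:

class _FakeMesh:  # hypothetical stand-in for a GPUMesh
  cluster = (2,)  # a 2-block CUDA thread-block cluster

for mesh in (None, _FakeMesh()):
  cluster = mesh.cluster if mesh is not None else ()
  print(cluster)  # () without a mesh, (2,) with one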
@@ -38,6 +38,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],

@@ -63,6 +64,7 @@ def pallas_call_lowering(

   lowering_result = lowering.lower_pipelined_jaxpr_to_module(
       grid_mapping,
+      mesh,
       jaxpr,
       compiler_params,
       cost_estimate,
@@ -20,7 +20,7 @@ import dataclasses
 import enum
 from functools import partial, reduce
 import types
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import jax
 from jax import lax
@@ -119,6 +119,7 @@ def _pallas_call_jvp_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,

@@ -133,6 +134,8 @@ def _pallas_call_jvp_rule(
     raise NotImplementedError
   if input_output_aliases:
     raise NotImplementedError("JVP with aliasing not supported.")
+  if mesh is not None:
+    raise NotImplementedError("pallas_call with a mesh does not support JVP")
   nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
   tangents = [t for t in tangents if type(t) is not ad_util.Zero]
   nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs

@@ -181,6 +184,7 @@ def _pallas_call_jvp_rule(
       *tangents,
       jaxpr=jvp_jaxpr,
       grid_mapping=jvp_grid_mapping,
+      mesh=mesh,
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
@@ -317,6 +321,7 @@ def _batch_with_explicit_loop(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,

@@ -384,6 +389,7 @@ def _batch_with_explicit_loop(
         *batch_args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,

@@ -413,6 +419,7 @@ def _pallas_call_batching_rule(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -421,6 +428,11 @@ def _pallas_call_batching_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
     backend: _Backend | None,
 ):
+  if mesh is not None:
+    raise NotImplementedError(
+        "pallas_call with a mesh does not support batching"
+    )
+
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
   ) -> jax.Array:
@@ -445,6 +457,7 @@ def _pallas_call_batching_rule(
         *args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,

@@ -478,6 +491,7 @@ def _pallas_call_batching_rule(
         dims=dynamic_grid_dims + dims,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,

@@ -512,6 +526,7 @@ def _pallas_call_batching_rule(
         dims=scalar_bdims + bdims,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,

@@ -890,6 +905,7 @@ def _pallas_call_batching_rule(
       *args,
       jaxpr=jaxpr,
       grid_mapping=batched_grid_mapping,
+      mesh=mesh,
       input_output_aliases=input_output_aliases,
       debug=debug,
       interpret=interpret,
@@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
-    backend: _Backend | None = None
+    backend: _Backend | None = None,
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)

@@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule(
       jaxpr=new_jaxpr,
       input_output_aliases=new_input_output_aliases,
       grid_mapping=new_grid_mapping,
+      mesh=mesh,
       debug=debug,
       interpret=interpret,
       compiler_params=compiler_params,
@@ -1526,16 +1544,6 @@ def pallas_call(
     invoke the Pallas kernel.

   """
-  if compiler_params is None:
-    compiler_params = {}
-  if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
-      raise ValueError(
-          f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
-    compiler_params = {
-        compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
-    }
-
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
   else:
@@ -1556,6 +1564,55 @@ def pallas_call(
         "If `grid_spec` is specified, then `scratch_shapes` must "
         f"be `()`. It is {scratch_shapes}")
   del grid, in_specs, out_specs
+  return _pallas_call(
+      kernel,
+      out_shape,
+      grid_spec=grid_spec,
+      input_output_aliases=input_output_aliases,
+      debug=debug,
+      interpret=interpret,
+      name=name,
+      compiler_params=compiler_params,
+      cost_estimate=cost_estimate,
+      backend=backend,
+  )
+
+
+def _pallas_call(
+    kernel: Callable[..., None],
+    out_shape: Any,
+    *,
+    grid_spec: GridSpec,
+    mesh: pallas_core.Mesh | None = None,
+    input_output_aliases: dict[int, int] = {},
+    debug: bool = False,
+    interpret: bool = False,
+    name: str | None = None,
+    compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
+    cost_estimate: CostEstimate | None = None,
+    backend: _Backend | None = None,
+):
+  if compiler_params is None:
+    compiler_params = {}
+  if isinstance(compiler_params, pallas_core.CompilerParams):
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
+      raise ValueError(
+          f"Unknown platform in compiler params: {compiler_params.PLATFORM}"
+      )
+    compiler_params = {
+        compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
+    }
+
+  if mesh is not None:
+    if tuple(mesh.shape.values()) != grid_spec.grid:
+      raise ValueError(
+          f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
+          f"shape {grid_spec.grid}."
+      )
+    if backend is not None:
+      raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
+    backend = cast(_Backend, mesh.backend)
+
   grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
   # TODO(necula): this canonicalization may be convenient for some usage
   # but it is lossy, because it prevents expressing functions that return
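The private _pallas_call is where the mesh is finally validated and the backend is derived from it. A self-contained sketch of just that check, with a hypothetical mesh object and assumed grid sizes:

import collections

class _FakeMesh:  # hypothetical; anything with backend/shape works
  backend = "mosaic_gpu"
  shape = collections.OrderedDict([("x", 2), ("wg", 2)])

mesh = _FakeMesh()
grid = (2, 2)                 # stands in for grid_spec.grid (sizes only)
if tuple(mesh.shape.values()) != grid:
  raise ValueError(f"Mesh shape {tuple(mesh.shape.values())} does not match grid shape {grid}.")
backend = mesh.backend        # so the caller must not pass backend= as well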
@@ -1643,6 +1700,7 @@ def pallas_call(
       debug=debug,
       interpret=interpret,
       grid_mapping=grid_mapping,
+      mesh=mesh,
       input_output_aliases=tuple(input_output_aliases.items()),
       compiler_params=compiler_params,
       cost_estimate=cost_estimate,
@@ -50,6 +50,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],

@@ -64,6 +65,8 @@ def pallas_call_lowering(
     raise NotImplementedError(
         "scalar prefetch not implemented in the Triton backend"
     )
+  if mesh is not None:
+    raise NotImplementedError("mesh is not supported in the Triton backend")
   triton_params = compiler_params.get("triton", compiler_params)
   num_warps = triton_params.get("num_warps", 4)
   num_warps = 4 if num_warps is None else num_warps