diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3407fe2e3..82fe9c2ba 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -270,6 +270,7 @@ def _tensorcore_mesh_discharge_rule( compiler_params=dict( mosaic=dict(dimension_semantics=("parallel",)), ), + backend="mosaic_tpu", )(*args) return out, () diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c4e785e50..2ed8910bf 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -512,6 +512,7 @@ def _gpu_mesh_discharge_rule( out_specs=[any_spec] * len(in_avals), input_output_aliases={i: i for i in range(len(in_avals))}, grid=tuple(mesh.shape.items()), + backend="mosaic_gpu", )(*args) return out, () diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 4db6057d2..2bed4a083 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Sequence import dataclasses from functools import partial, reduce import itertools -from typing import Any +from typing import Any, Literal import jax from jax import lax @@ -362,6 +362,7 @@ def _pallas_call_jvp_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + backend: _Backend | None, ): if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") @@ -425,7 +426,8 @@ def _pallas_call_jvp_rule( input_output_aliases=(), compiler_params=compiler_params, cost_estimate=jvp_cost_estimate, - out_avals=(*out_avals, *out_avals) + out_avals=(*out_avals, *out_avals), + backend=backend, ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents @@ -560,6 +562,7 @@ def _batch_with_explicit_loop( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + backend: _Backend | None, ): """Batch the pallas_call by calling it in loop over the batch size. @@ -627,6 +630,7 @@ def _batch_with_explicit_loop( compiler_params=compiler_params, cost_estimate=cost_estimate, out_avals=out_avals, + backend=backend, ) for i, batch_out_array in enumerate(batch_out): state[i] = jax.lax.dynamic_update_index_in_dim( @@ -656,6 +660,7 @@ def _pallas_call_batching_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + backend: _Backend | None, ): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -688,6 +693,7 @@ def _pallas_call_batching_rule( compiler_params=compiler_params, cost_estimate=cost_estimate, out_avals=out_avals, + backend=backend, ) return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) @@ -721,6 +727,7 @@ def _pallas_call_batching_rule( compiler_params=compiler_params, cost_estimate=cost_estimate, out_avals=out_avals, + backend=backend, ) else: pass # No dynamic grid dimensions @@ -755,6 +762,7 @@ def _pallas_call_batching_rule( compiler_params=compiler_params, cost_estimate=cost_estimate, out_avals=out_avals, + backend=backend, ) if not dims: @@ -1128,6 +1136,7 @@ def _pallas_call_batching_rule( compiler_params=compiler_params, cost_estimate=batched_cost_estimate, out_avals=batched_out_avals, + backend=backend, ) return out, (0,) * len(out) @@ -1441,9 +1450,15 @@ def _unsupported_lowering_error(platform: str) -> Exception: " https://jax.readthedocs.io/en/latest/installation.html." ) +_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] + def _pallas_call_lowering( - ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params + ctx: mlir.LoweringRuleContext, + *in_nodes, + interpret: bool, + backend: _Backend | None, + **params, ): if params['jaxpr'].constvars: raise ValueError('Cannot lower a pallas_call with constants.') @@ -1460,6 +1475,8 @@ def _pallas_call_lowering( def tpu_lowering(ctx: mlir.LoweringRuleContext, *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], **params): + if backend and backend != "mosaic_tpu": + raise ValueError("Only mosaic backend supported for TPU") if mosaic_tpu_backend is None: raise _unsupported_lowering_error("tpu") return mosaic_tpu_backend.pallas_call_tpu_lowering_rule( @@ -1470,12 +1487,21 @@ def _pallas_call_lowering( *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], **params): try: - if _PALLAS_USE_MOSAIC_GPU.value: - from jax._src.pallas.mosaic_gpu import pallas_call_registration - else: - from jax._src.pallas.triton import pallas_call_registration # type: ignore - except ImportError: + match backend: + case "mosaic_gpu": + from jax._src.pallas.mosaic_gpu import pallas_call_registration + case "triton": + from jax._src.pallas.triton import pallas_call_registration # type: ignore + case None: + if _PALLAS_USE_MOSAIC_GPU.value: + from jax._src.pallas.mosaic_gpu import pallas_call_registration + else: + from jax._src.pallas.triton import pallas_call_registration # type: ignore + case _: + raise ValueError(f"Unsupported backend: {backend}") + except ImportError as e: raise _unsupported_lowering_error("gpu") + return pallas_call_registration.pallas_call_lowering( ctx, *in_nodes, **params ) @@ -1544,6 +1570,7 @@ def _pallas_call_state_discharge_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + backend: _Backend | None = None ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1645,6 +1672,7 @@ def _pallas_call_state_discharge_rule( compiler_params=compiler_params, cost_estimate=cost_estimate, out_avals=new_out_avals, + backend=backend, ) refs_out, rest = split_list(out_flat, [num_refs]) updated_vals_in = refs_out + [None] * len(rest_in_avals) @@ -1666,6 +1694,7 @@ def pallas_call( name: str | None = None, compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, cost_estimate: CostEstimate | None = None, + backend: _Backend | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. @@ -1715,6 +1744,8 @@ def pallas_call( platform is either 'mosaic' or 'triton'. It is also possible to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. + backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu" + determining the backend to be used. None means let pallas decide. Returns: @@ -1857,6 +1888,7 @@ def pallas_call( input_output_aliases=tuple(input_output_aliases.items()), compiler_params=compiler_params, cost_estimate=cost_estimate, + backend=backend, ) out = tree_util.tree_unflatten(out_tree, out_flat) return out