[pallas] Support for setting an explicit backend on pallas_call.

PiperOrigin-RevId: 688511303
Christos Perivolaropoulos 2024-10-22 05:36:37 -07:00 committed by jax authors
parent 2db03ba54b
commit 4f9356361a
3 changed files with 42 additions and 8 deletions
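
The user-facing piece of this change is a new backend keyword on pl.pallas_call, accepting one of the literals "mosaic_tpu", "triton" or "mosaic_gpu" (or None to keep the old automatic choice). Below is a minimal sketch of how a caller might pin a GPU kernel to the Triton lowering; the kernel and shapes are illustrative, and the chosen backend has to match the platform the call is lowered for.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Elementwise kernel body: read the whole block, add one, write it back.
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      backend="triton",  # new in this change; backend=None keeps the old behaviour
  )(x)

x = jnp.zeros((8, 128), jnp.float32)
y = add_one(x)  # assumes a GPU where the Triton backend is available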


@ -270,6 +270,7 @@ def _tensorcore_mesh_discharge_rule(
compiler_params=dict(
mosaic=dict(dimension_semantics=("parallel",)),
),
backend="mosaic_tpu",
)(*args)
return out, ()


@ -512,6 +512,7 @@ def _gpu_mesh_discharge_rule(
out_specs=[any_spec] * len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=tuple(mesh.shape.items()),
backend="mosaic_gpu",
)(*args)
return out, ()


@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Sequence
import dataclasses
from functools import partial, reduce
import itertools
from typing import Any
from typing import Any, Literal
import jax
from jax import lax
@ -362,6 +362,7 @@ def _pallas_call_jvp_rule(
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None,
):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
@ -425,7 +426,8 @@ def _pallas_call_jvp_rule(
input_output_aliases=(),
compiler_params=compiler_params,
cost_estimate=jvp_cost_estimate,
out_avals=(*out_avals, *out_avals)
out_avals=(*out_avals, *out_avals),
backend=backend,
)
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
return out_primals, out_tangents
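
The JVP rule now threads the new parameter through, so differentiating a backend-pinned pallas_call emits primal and tangent computations on the same backend. A hedged sketch, assuming a simple elementwise kernel for which Pallas autodiff is supported (no input/output aliasing, static grid); square_kernel, the shapes and the choice of "triton" are illustrative.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def square_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] * x_ref[...]

def square(x):
  return pl.pallas_call(
      square_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      backend="triton",  # assumed: GPU platform with Triton available
  )(x)

x = jnp.arange(8.0)
# The pallas_call emitted by the JVP rule carries the same backend.
y, y_dot = jax.jvp(square, (x,), (jnp.ones_like(x),))
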
@ -560,6 +562,7 @@ def _batch_with_explicit_loop(
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None,
):
"""Batch the pallas_call by calling it in loop over the batch size.
@ -627,6 +630,7 @@ def _batch_with_explicit_loop(
compiler_params=compiler_params,
cost_estimate=cost_estimate,
out_avals=out_avals,
backend=backend,
)
for i, batch_out_array in enumerate(batch_out):
state[i] = jax.lax.dynamic_update_index_in_dim(
@ -656,6 +660,7 @@ def _pallas_call_batching_rule(
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None,
):
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
@ -688,6 +693,7 @@ def _pallas_call_batching_rule(
compiler_params=compiler_params,
cost_estimate=cost_estimate,
out_avals=out_avals,
backend=backend,
)
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
@ -721,6 +727,7 @@ def _pallas_call_batching_rule(
compiler_params=compiler_params,
cost_estimate=cost_estimate,
out_avals=out_avals,
backend=backend,
)
else:
pass # No dynamic grid dimensions
@ -755,6 +762,7 @@ def _pallas_call_batching_rule(
compiler_params=compiler_params,
cost_estimate=cost_estimate,
out_avals=out_avals,
backend=backend,
)
if not dims:
@ -1128,6 +1136,7 @@ def _pallas_call_batching_rule(
compiler_params=compiler_params,
cost_estimate=batched_cost_estimate,
out_avals=batched_out_avals,
backend=backend,
)
return out, (0,) * len(out)
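
The batching rules (both the grid-based path and the explicit-loop fallback) forward backend as well, so a vmapped call keeps its pinning. A sketch under the same assumptions as above, here pinning Mosaic GPU instead; scale_kernel and the batch shape are illustrative.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
  o_ref[...] = 2.0 * x_ref[...]

def scale(x):
  return pl.pallas_call(
      scale_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      backend="mosaic_gpu",  # assumed: GPU with Mosaic GPU support built in
  )(x)

xs = jnp.ones((4, 8, 128), jnp.float32)
# Every element of the leading batch dimension lowers through Mosaic GPU.
ys = jax.vmap(scale)(xs)
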
@ -1441,9 +1450,15 @@ def _unsupported_lowering_error(platform: str) -> Exception:
" https://jax.readthedocs.io/en/latest/installation.html."
)
_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]
def _pallas_call_lowering(
ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
ctx: mlir.LoweringRuleContext,
*in_nodes,
interpret: bool,
backend: _Backend | None,
**params,
):
if params['jaxpr'].constvars:
raise ValueError('Cannot lower a pallas_call with constants.')
@ -1460,6 +1475,8 @@ def _pallas_call_lowering(
def tpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
if backend and backend != "mosaic_tpu":
raise ValueError("Only mosaic backend supported for TPU")
if mosaic_tpu_backend is None:
raise _unsupported_lowering_error("tpu")
return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
@ -1470,12 +1487,21 @@ def _pallas_call_lowering(
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
try:
match backend:
case "mosaic_gpu":
from jax._src.pallas.mosaic_gpu import pallas_call_registration
case "triton":
from jax._src.pallas.triton import pallas_call_registration # type: ignore
case None:
if _PALLAS_USE_MOSAIC_GPU.value:
from jax._src.pallas.mosaic_gpu import pallas_call_registration
else:
from jax._src.pallas.triton import pallas_call_registration # type: ignore
except ImportError:
case _:
raise ValueError(f"Unsupported backend: {backend}")
except ImportError as e:
raise _unsupported_lowering_error("gpu")
return pallas_call_registration.pallas_call_lowering(
ctx, *in_nodes, **params
)
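
For readers skimming the hunk above, the GPU dispatch reduces to: an explicit backend overrides the _PALLAS_USE_MOSAIC_GPU flag, None preserves the previous flag-driven choice, and any other value is rejected. A standalone restatement of that precedence follows; the helper name and the boolean flag argument are hypothetical, and the real code imports the registration modules as shown in the diff.

def _select_gpu_lowering(backend, use_mosaic_gpu_flag):
  # Hypothetical helper mirroring gpu_lowering's match statement: an explicit
  # backend wins, None falls back to the flag, anything else is an error.
  match backend:
    case "mosaic_gpu":
      return "jax._src.pallas.mosaic_gpu.pallas_call_registration"
    case "triton":
      return "jax._src.pallas.triton.pallas_call_registration"
    case None:
      return ("jax._src.pallas.mosaic_gpu.pallas_call_registration"
              if use_mosaic_gpu_flag
              else "jax._src.pallas.triton.pallas_call_registration")
    case _:
      raise ValueError(f"Unsupported backend: {backend}")
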
@ -1544,6 +1570,7 @@ def _pallas_call_state_discharge_rule(
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None = None
):
del avals_out
assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@ -1645,6 +1672,7 @@ def _pallas_call_state_discharge_rule(
compiler_params=compiler_params,
cost_estimate=cost_estimate,
out_avals=new_out_avals,
backend=backend,
)
refs_out, rest = split_list(out_flat, [num_refs])
updated_vals_in = refs_out + [None] * len(rest_in_avals)
@ -1666,6 +1694,7 @@ def pallas_call(
name: str | None = None,
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
cost_estimate: CostEstimate | None = None,
backend: _Backend | None = None,
) -> Callable[..., Any]:
"""Invokes a Pallas kernel on some inputs.
@ -1715,6 +1744,8 @@ def pallas_call(
platform is either 'mosaic' or 'triton'. It is also possible
to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
`jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu"
determining the backend to be used. None means let pallas decide.
Returns:
@ -1857,6 +1888,7 @@ def pallas_call(
input_output_aliases=tuple(input_output_aliases.items()),
compiler_params=compiler_params,
cost_estimate=cost_estimate,
backend=backend,
)
out = tree_util.tree_unflatten(out_tree, out_flat)
return out