[pallas] Support for setting explicit backends to pallas_call.
PiperOrigin-RevId: 688511303
parent 2db03ba54b
commit 4f9356361a
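
For orientation, a minimal usage sketch of the new argument (not part of the diff): it assumes the public `jax.experimental.pallas.pallas_call` re-export picks up the keyword added below, and the kernel and shapes are made up for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Trivial elementwise kernel; only the call signature matters here.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((128,), jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

# Default behavior: backend=None lets Pallas pick the platform default.
y_default = pl.pallas_call(add_one_kernel, out_shape=out_shape)(x)

# Explicit selection: force the Mosaic GPU backend on a GPU host.
y_mosaic = pl.pallas_call(
    add_one_kernel, out_shape=out_shape, backend="mosaic_gpu"
)(x)
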
@@ -270,6 +270,7 @@ def _tensorcore_mesh_discharge_rule(
       compiler_params=dict(
           mosaic=dict(dimension_semantics=("parallel",)),
       ),
+      backend="mosaic_tpu",
   )(*args)
   return out, ()
@@ -512,6 +512,7 @@ def _gpu_mesh_discharge_rule(
       out_specs=[any_spec] * len(in_avals),
       input_output_aliases={i: i for i in range(len(in_avals))},
       grid=tuple(mesh.shape.items()),
+      backend="mosaic_gpu",
   )(*args)
   return out, ()
@@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Sequence
 import dataclasses
 from functools import partial, reduce
 import itertools
-from typing import Any
+from typing import Any, Literal

 import jax
 from jax import lax
@@ -362,6 +362,7 @@ def _pallas_call_jvp_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError("interpret with dynamic grid bounds unsupported")
@@ -425,7 +426,8 @@ def _pallas_call_jvp_rule(
       input_output_aliases=(),
       compiler_params=compiler_params,
       cost_estimate=jvp_cost_estimate,
-      out_avals=(*out_avals, *out_avals)
+      out_avals=(*out_avals, *out_avals),
+      backend=backend,
   )
   out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
   return out_primals, out_tangents
@@ -560,6 +562,7 @@ def _batch_with_explicit_loop(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   """Batch the pallas_call by calling it in loop over the batch size.
@@ -627,6 +630,7 @@ def _batch_with_explicit_loop(
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     for i, batch_out_array in enumerate(batch_out):
       state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -656,6 +660,7 @@ def _pallas_call_batching_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -688,6 +693,7 @@ def _pallas_call_batching_rule(
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
@@ -721,6 +727,7 @@ def _pallas_call_batching_rule(
           compiler_params=compiler_params,
           cost_estimate=cost_estimate,
           out_avals=out_avals,
+          backend=backend,
       )
     else:
       pass # No dynamic grid dimensions
@@ -755,6 +762,7 @@ def _pallas_call_batching_rule(
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )

   if not dims:
@@ -1128,6 +1136,7 @@ def _pallas_call_batching_rule(
       compiler_params=compiler_params,
       cost_estimate=batched_cost_estimate,
       out_avals=batched_out_avals,
+      backend=backend,
   )
   return out, (0,) * len(out)
@@ -1441,9 +1450,15 @@ def _unsupported_lowering_error(platform: str) -> Exception:
       " https://jax.readthedocs.io/en/latest/installation.html."
   )

+_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]
+

 def _pallas_call_lowering(
-    ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
+    ctx: mlir.LoweringRuleContext,
+    *in_nodes,
+    interpret: bool,
+    backend: _Backend | None,
+    **params,
 ):
   if params['jaxpr'].constvars:
     raise ValueError('Cannot lower a pallas_call with constants.')
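
The `_Backend` Literal alias introduced above restricts the accepted backend strings at type-check time. A small standalone sketch of that pattern (the `check_backend` helper is illustrative, not part of the diff):

from typing import Literal, get_args

_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]

def check_backend(backend: _Backend | None) -> None:
  # Runtime guard mirroring what a type checker enforces statically.
  if backend is not None and backend not in get_args(_Backend):
    raise ValueError(f"Unsupported backend: {backend}")

check_backend("triton")   # accepted
check_backend(None)       # accepted: None means "let Pallas decide"
# check_backend("cuda")   # flagged by a type checker; raises ValueError at runtime
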
@@ -1460,6 +1475,8 @@ def _pallas_call_lowering(
   def tpu_lowering(ctx: mlir.LoweringRuleContext,
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
+    if backend and backend != "mosaic_tpu":
+      raise ValueError("Only mosaic backend supported for TPU")
     if mosaic_tpu_backend is None:
       raise _unsupported_lowering_error("tpu")
     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
@@ -1470,12 +1487,21 @@ def _pallas_call_lowering(
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
     try:
+      match backend:
+        case "mosaic_gpu":
+          from jax._src.pallas.mosaic_gpu import pallas_call_registration
+        case "triton":
+          from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        case None:
           if _PALLAS_USE_MOSAIC_GPU.value:
             from jax._src.pallas.mosaic_gpu import pallas_call_registration
           else:
             from jax._src.pallas.triton import pallas_call_registration  # type: ignore
-    except ImportError:
+        case _:
+          raise ValueError(f"Unsupported backend: {backend}")
+    except ImportError as e:
       raise _unsupported_lowering_error("gpu")

     return pallas_call_registration.pallas_call_lowering(
         ctx, *in_nodes, **params
     )
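
Taken together, the TPU and GPU lowering hunks above determine the user-visible behavior when an explicit backend does not match the lowering target. A hedged sketch of that behavior (the kernel is illustrative; the error messages are quoted from the diff):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Minimal kernel; only the backend/lowering interaction matters here.
  o_ref[...] = x_ref[...]

x = jnp.ones((8,), jnp.float32)
call = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="triton",  # explicit backend request
)

# On a GPU host this should lower through the Triton registration module.
# On a TPU host, the tpu_lowering check above is expected to raise
# ValueError("Only mosaic backend supported for TPU").
# A string outside _Backend (e.g. "cuda") would reach the `case _` branch and
# raise ValueError(f"Unsupported backend: {backend}") during GPU lowering.
y = call(x)
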
@@ -1544,6 +1570,7 @@ def _pallas_call_state_discharge_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None = None
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1645,6 +1672,7 @@ def _pallas_call_state_discharge_rule(
       compiler_params=compiler_params,
       cost_estimate=cost_estimate,
       out_avals=new_out_avals,
+      backend=backend,
   )
   refs_out, rest = split_list(out_flat, [num_refs])
   updated_vals_in = refs_out + [None] * len(rest_in_avals)
@@ -1666,6 +1694,7 @@ def pallas_call(
     name: str | None = None,
     compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
     cost_estimate: CostEstimate | None = None,
+    backend: _Backend | None = None,
 ) -> Callable[..., Any]:
   """Invokes a Pallas kernel on some inputs.
@@ -1715,6 +1744,8 @@ def pallas_call(
       platform is either 'mosaic' or 'triton'. It is also possible
       to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
       `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
+    backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu"
+      determining the backend to be used. None means let pallas decide.

   Returns:
@@ -1857,6 +1888,7 @@ def pallas_call(
         input_output_aliases=tuple(input_output_aliases.items()),
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
+        backend=backend,
     )
     out = tree_util.tree_unflatten(out_tree, out_flat)
     return out