Added installation instructions to the error in _pallas_call_lowering

PiperOrigin-RevId: 621168804
This commit is contained in:
Sergei Lebedev 2024-04-02 07:35:40 -07:00 committed by jax authors
parent 4c41c12e21
commit 2ee4c0f644

View File

@ -496,6 +496,15 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
return name
def _unsupported_lowering_error(platform: str) -> Exception:
return ValueError(
f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
" install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install"
" jaxlib TPU and libtpu. See"
" https://jax.readthedocs.io/en/latest/installation.html."
)
def _pallas_call_lowering(
ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
):
@ -526,7 +535,7 @@ def _pallas_call_lowering(
ctx, *in_nodes, interpret=interpret, **params
)
raise ValueError(f"Cannot lower pallas_call on platform: {platform}.")
raise _unsupported_lowering_error(platform)
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)