mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added installation instructions to the error in _pallas_call_lowering
PiperOrigin-RevId: 621168804
This commit is contained in:
parent
4c41c12e21
commit
2ee4c0f644
@ -496,6 +496,15 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
|
||||
return name
|
||||
|
||||
|
||||
def _unsupported_lowering_error(platform: str) -> Exception:
|
||||
return ValueError(
|
||||
f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
|
||||
" install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install"
|
||||
" jaxlib TPU and libtpu. See"
|
||||
" https://jax.readthedocs.io/en/latest/installation.html."
|
||||
)
|
||||
|
||||
|
||||
def _pallas_call_lowering(
|
||||
ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
|
||||
):
|
||||
@ -526,7 +535,7 @@ def _pallas_call_lowering(
|
||||
ctx, *in_nodes, interpret=interpret, **params
|
||||
)
|
||||
|
||||
raise ValueError(f"Cannot lower pallas_call on platform: {platform}.")
|
||||
raise _unsupported_lowering_error(platform)
|
||||
|
||||
|
||||
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
|
||||
|
Loading…
x
Reference in New Issue
Block a user