mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check if the Triton dialect bindings are available in lib/triton.py
IIRC we used to import these bindings in lib/__init__.py which is imported as part of the top-level jax package. So, it did make sense to delay the check until we actually need the bindings. However, we have since moved the bindings to lib/triton.py and thus we could move the check there. PiperOrigin-RevId: 607196039
This commit is contained in:
parent
75104375ae
commit
b4c8b0e4fb
@ -14,12 +14,11 @@
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
from typing import Any
|
||||
|
||||
dialect: Any = None
|
||||
try:
|
||||
from jaxlib.triton import dialect # pytype: disable=import-error
|
||||
except ImportError:
|
||||
# TODO(slebedev): Switch to a jaxlib version guard, once Triton bindings
|
||||
# are bundled with jaxlib.
|
||||
pass
|
||||
except ImportError as e:
|
||||
raise ModuleNotFoundError(
|
||||
"Cannot import the Triton bindings. You may need a newer version of"
|
||||
" jaxlib. Try installing a nightly wheel following instructions in"
|
||||
" https://jax.readthedocs.io/en/latest/installation.html#nightly-installation"
|
||||
) from e
|
||||
|
@ -65,13 +65,6 @@ import triton.backends.nvidia.compiler as cb
|
||||
# TODO(sharadmv): Enable type checking.
|
||||
# mypy: ignore-errors
|
||||
|
||||
if tt_dialect is None:
|
||||
raise RuntimeError(
|
||||
"Cannot import the Triton bindings. You may need a newer version of"
|
||||
" jaxlib. Try installing a nightly wheel following instructions in"
|
||||
" https://jax.readthedocs.io/en/latest/installation.html#nightly-installation"
|
||||
)
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
partial = functools.partial
|
||||
|
Loading…
x
Reference in New Issue
Block a user