Check if the Triton dialect bindings are available in lib/triton.py

IIRC we used to import these bindings in lib/__init__.py which is imported
as part of the top-level jax package. So, it did make sense to delay the
check until we actually need the bindings.

However, we have since moved the bindings to lib/triton.py and thus we could
move the check there.

PiperOrigin-RevId: 607196039
This commit is contained in:
Sergei Lebedev 2024-02-14 20:48:30 -08:00 committed by jax authors
parent 75104375ae
commit b4c8b0e4fb
2 changed files with 6 additions and 14 deletions

View File

@ -14,12 +14,11 @@
# ruff: noqa
from typing import Any
dialect: Any = None
try:
from jaxlib.triton import dialect # pytype: disable=import-error
except ImportError:
# TODO(slebedev): Switch to a jaxlib version guard, once Triton bindings
# are bundled with jaxlib.
pass
except ImportError as e:
raise ModuleNotFoundError(
"Cannot import the Triton bindings. You may need a newer version of"
" jaxlib. Try installing a nightly wheel following instructions in"
" https://jax.readthedocs.io/en/latest/installation.html#nightly-installation"
) from e

View File

@ -65,13 +65,6 @@ import triton.backends.nvidia.compiler as cb
# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors
if tt_dialect is None:
raise RuntimeError(
"Cannot import the Triton bindings. You may need a newer version of"
" jaxlib. Try installing a nightly wheel following instructions in"
" https://jax.readthedocs.io/en/latest/installation.html#nightly-installation"
)
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial