Use jaxlib version guard for triton instead of xla_extension_version

PiperOrigin-RevId: 534974834
This commit is contained in:
Sharad Vikram 2023-05-24 14:06:10 -07:00 committed by jax authors
parent 6a54ebd031
commit 4fb834b351

View File

@ -116,8 +116,13 @@ import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
xla_extension_version: int = getattr(xla_client, '_version', 0)
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
if xla_extension_version >= 154:
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
if jaxlib.version.__version_info__ >= (0, 4, 11):
# TODO(sharadmv): make this unconditional when minimum jaxlib version is
# bumped to 0.4.11
try:
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
except ModuleNotFoundError:
pass
# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version