mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
This commit is contained in:
parent
6a54ebd031
commit
4fb834b351
@ -116,8 +116,13 @@ import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
|
||||
xla_extension_version: int = getattr(xla_client, '_version', 0)
|
||||
|
||||
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
|
||||
if xla_extension_version >= 154:
|
||||
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
|
||||
if jaxlib.version.__version_info__ >= (0, 4, 11):
|
||||
# TODO(sharadmv): make this unconditional when minimum jaxlib version is
|
||||
# bumped to 0.4.11
|
||||
try:
|
||||
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# Version number for MLIR:Python APIs, provided by jaxlib.
|
||||
mlir_api_version = xla_client.mlir_api_version
|
||||
|
Loading…
x
Reference in New Issue
Block a user