diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 3fff55d9e..caff1c731 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -429,7 +429,9 @@ def stablehlo_version_at_least(required_version: str): plugin_version = xla_bridge.backend_stablehlo_version() if plugin_version is None: return True - return hlo.get_smaller_version(plugin_version, required_version) == plugin_version + return hlo.get_smaller_version( + ".".join(map(str, plugin_version)), required_version + ) == plugin_version def get_tpu_version() -> int: if device_under_test() != "tpu": diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index c23510183..5fb42c333 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -32,7 +32,7 @@ import pkgutil import platform as py_platform import threading import traceback -from typing import Any, Union +from typing import Any, Sequence, Union import warnings from jax._src import config @@ -1086,7 +1086,7 @@ def backend_xla_version(platform=None) -> int | None: backend = get_backend(platform) return getattr(backend, "xla_version", None) -def backend_stablehlo_version(platform=None) -> int | None: +def backend_stablehlo_version(platform=None) -> Sequence[int] | None: """Returns the StableHLO version of the backend. Returns None if the backend does not use PJRT C API or does not have