Fix stablehlo version comparison in test utilities.

PiperOrigin-RevId: 747547427
This commit is contained in:
Peter Hawkins 2025-04-14 13:32:50 -07:00 committed by jax authors
parent d014912671
commit 8930a67e63
2 changed files with 5 additions and 3 deletions

View File

@ -429,7 +429,9 @@ def stablehlo_version_at_least(required_version: str):
plugin_version = xla_bridge.backend_stablehlo_version()
if plugin_version is None:
return True
return hlo.get_smaller_version(plugin_version, required_version) == plugin_version
return hlo.get_smaller_version(
".".join(map(str, plugin_version)), required_version
) == plugin_version
def get_tpu_version() -> int:
if device_under_test() != "tpu":

View File

@ -32,7 +32,7 @@ import pkgutil
import platform as py_platform
import threading
import traceback
from typing import Any, Union
from typing import Any, Sequence, Union
import warnings
from jax._src import config
@ -1086,7 +1086,7 @@ def backend_xla_version(platform=None) -> int | None:
backend = get_backend(platform)
return getattr(backend, "xla_version", None)
def backend_stablehlo_version(platform=None) -> int | None:
def backend_stablehlo_version(platform=None) -> Sequence[int] | None:
"""Returns the StableHLO version of the backend.
Returns None if the backend does not use PJRT C API or does not have