mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix stablehlo version comparison in test utilities.
PiperOrigin-RevId: 747547427
This commit is contained in:
parent
d014912671
commit
8930a67e63
@ -429,7 +429,9 @@ def stablehlo_version_at_least(required_version: str):
|
||||
plugin_version = xla_bridge.backend_stablehlo_version()
|
||||
if plugin_version is None:
|
||||
return True
|
||||
return hlo.get_smaller_version(plugin_version, required_version) == plugin_version
|
||||
return hlo.get_smaller_version(
|
||||
".".join(map(str, plugin_version)), required_version
|
||||
) == plugin_version
|
||||
|
||||
def get_tpu_version() -> int:
|
||||
if device_under_test() != "tpu":
|
||||
|
@ -32,7 +32,7 @@ import pkgutil
|
||||
import platform as py_platform
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Union
|
||||
from typing import Any, Sequence, Union
|
||||
import warnings
|
||||
|
||||
from jax._src import config
|
||||
@ -1086,7 +1086,7 @@ def backend_xla_version(platform=None) -> int | None:
|
||||
backend = get_backend(platform)
|
||||
return getattr(backend, "xla_version", None)
|
||||
|
||||
def backend_stablehlo_version(platform=None) -> int | None:
|
||||
def backend_stablehlo_version(platform=None) -> Sequence[int] | None:
|
||||
"""Returns the StableHLO version of the backend.
|
||||
|
||||
Returns None if the backend does not use PJRT C API or does not have
|
||||
|
Loading…
x
Reference in New Issue
Block a user