mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
Change compilation_cache_test to compile MHLO instead of classic HLO.
Support for classic HLO is being dropped from the .compile() API. In passing, also remove some obsolete version checks. The minimum xla_client API version is currently 109. PiperOrigin-RevId: 496708463
This commit is contained in:
parent
843bc43790
commit
357d044965
@ -22,7 +22,6 @@ from typing import List, Optional
|
||||
from jax.experimental.compilation_cache.gfile_cache import GFileCache
|
||||
from jax._src import path as pathlib
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import xla_client
|
||||
from jax.interpreters import xla
|
||||
|
||||
_cache = None
|
||||
@ -129,8 +128,7 @@ def _hash_computation(hash_obj, xla_computation):
|
||||
hash_obj.update(scrubbed_hlo)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
# TODO(parkers): simplify this code when jaxlib >= 0.3.23 is the minimum.
|
||||
expected_num_compile_options = 35 if xla_client._version >= 104 else 33
|
||||
expected_num_compile_options = 35
|
||||
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
|
||||
f"Unexpected number of CompileOption fields: "
|
||||
f"{len(dir(compile_options_obj))}. This likely: means that an extra "
|
||||
|
@ -207,8 +207,12 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
def test_diff_executables(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
|
||||
computation1 = str(jax.jit(lambda x, y: x + y)
|
||||
.lower(1, 1)
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
computation2 = str(jax.jit(lambda x, y: x * y)
|
||||
.lower(2, 2)
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -224,8 +228,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
def test_put_executable(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(np.int32(1),
|
||||
np.int32(1))
|
||||
computation = str(jax.jit(lambda x, y: x + y)
|
||||
.lower(np.int32(1), np.int32(1))
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
|
Loading…
x
Reference in New Issue
Block a user