Change compilation_cache_test to compile MHLO instead of classic HLO.

Support for classic HLO is being dropped from the .compile() API.

In passing, also remove some obsolete version checks. The minimum xla_client API version is currently 109.

PiperOrigin-RevId: 496708463
This commit is contained in:
Peter Hawkins 2022-12-20 11:25:42 -08:00 committed by jax authors
parent 843bc43790
commit 357d044965
2 changed files with 10 additions and 7 deletions

View File

@ -22,7 +22,6 @@ from typing import List, Optional
from jax.experimental.compilation_cache.gfile_cache import GFileCache
from jax._src import path as pathlib
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
from jax.interpreters import xla
_cache = None
@ -129,8 +128,7 @@ def _hash_computation(hash_obj, xla_computation):
hash_obj.update(scrubbed_hlo)
def _hash_compile_options(hash_obj, compile_options_obj):
# TODO(parkers): simplify this code when jaxlib >= 0.3.23 is the minimum.
expected_num_compile_options = 35 if xla_client._version >= 104 else 33
expected_num_compile_options = 35
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
f"Unexpected number of CompileOption fields: "
f"{len(dir(compile_options_obj))}. This likely: means that an extra "

View File

@ -207,8 +207,12 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_diff_executables(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
computation1 = str(jax.jit(lambda x, y: x + y)
.lower(1, 1)
.compiler_ir(dialect="mhlo"))
computation2 = str(jax.jit(lambda x, y: x * y)
.lower(2, 2)
.compiler_ir(dialect="mhlo"))
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
@ -224,8 +228,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_put_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation = jax.xla_computation(lambda x, y: x + y)(np.int32(1),
np.int32(1))
computation = str(jax.jit(lambda x, y: x + y)
.lower(np.int32(1), np.int32(1))
.compiler_ir(dialect="mhlo"))
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()