Remove support for classic HLO computations in compilation cache.

These are never used except in this unit test any more; we always use MLIR.

PiperOrigin-RevId: 507473543
This commit is contained in:
Peter Hawkins 2023-02-06 07:24:09 -08:00 committed by jax authors
parent 077ff29729
commit fbbd442db7
2 changed files with 8 additions and 8 deletions

View File

@ -124,12 +124,12 @@ def _hash_computation(hash_obj, xla_computation):
elif isinstance(xla_computation, str):
serialized_hlo = xla_computation.encode() # MLIR module text
else:
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
raise TypeError(f"Unknown computation type {type(xla_computation)}")
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
hash_obj.update(scrubbed_hlo)
def _hash_compile_options(hash_obj, compile_options_obj):
expected_num_compile_options = 37 if xla.xc._version >= 114 else 35
expected_num_compile_options = 37 if xla_extension_version >= 114 else 35
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
f"Unexpected number of CompileOption fields: "
f"{len(dir(compile_options_obj))}. This likely: means that an extra "

View File

@ -131,7 +131,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(hash1, hash2)
def test_same_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
@ -139,7 +139,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.get_cache_key(computation, compile_options, backend))
def test_different_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
compile_options_not_filled = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
compile_options_filled = self.filled_compile_options()
@ -148,8 +148,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.get_cache_key(computation, compile_options_filled, backend))
def test_different_computations(self):
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
@ -160,7 +160,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
if jtu.is_device_tpu_v4():
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
@ -202,7 +202,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_get_no_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()