mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove support for classic HLO computations in compilation cache.
These are never used except in this unit test any more; we always use MLIR. PiperOrigin-RevId: 507473543
This commit is contained in:
parent
077ff29729
commit
fbbd442db7
@ -124,12 +124,12 @@ def _hash_computation(hash_obj, xla_computation):
|
||||
elif isinstance(xla_computation, str):
|
||||
serialized_hlo = xla_computation.encode() # MLIR module text
|
||||
else:
|
||||
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
|
||||
raise TypeError(f"Unknown computation type {type(xla_computation)}")
|
||||
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
|
||||
hash_obj.update(scrubbed_hlo)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
expected_num_compile_options = 37 if xla.xc._version >= 114 else 35
|
||||
expected_num_compile_options = 37 if xla_extension_version >= 114 else 35
|
||||
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
|
||||
f"Unexpected number of CompileOption fields: "
|
||||
f"{len(dir(compile_options_obj))}. This likely: means that an extra "
|
||||
|
@ -131,7 +131,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(hash1, hash2)
|
||||
|
||||
def test_same_hash_key(self):
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -139,7 +139,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.get_cache_key(computation, compile_options, backend))
|
||||
|
||||
def test_different_hash_key(self):
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
compile_options_not_filled = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
@ -148,8 +148,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.get_cache_key(computation, compile_options_filled, backend))
|
||||
|
||||
def test_different_computations(self):
|
||||
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
|
||||
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -160,7 +160,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
if jtu.is_device_tpu_v4():
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -202,7 +202,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
def test_get_no_executable(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
|
Loading…
x
Reference in New Issue
Block a user