From 23e4f0b471122a07a8cd77d2ce61e4c6c10bf3ca Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 11 Sep 2023 12:07:48 -0700 Subject: [PATCH] Hash serialized topology description for new cache key generation. The original cache key generation hashes devices and backend. This is not future proof: it does not work for accelerators other than TPUs. Change this to use the serialized version of PjRtTopologyDescription which is supported for all accelerators. Note: . CPU and PjRt C API not supported as yet. . Stream Executor will not be supported. Testing: revised unit test. PiperOrigin-RevId: 564461564 --- jax/_src/cache_key.py | 27 +++++++++++++++++++++------ jax/_src/compiler.py | 14 ++++++++++---- tests/cache_key_test.py | 21 +++++++++++++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index cfc8292ff..9ce4bddb2 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -86,11 +86,9 @@ def get(module: ir.Module, """ entries = [ ("computation", lambda hash_obj: _hash_computation(hash_obj, module)), - ("devices", lambda hash_obj: _hash_devices(hash_obj, devices)), ("jax_lib version", lambda hash_obj: hash_obj.update( bytes(jaxlib_version_str.encode("utf-8")))), - ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)), ("XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())), ("compression", @@ -101,12 +99,24 @@ def get(module: ir.Module, ("compile_options", lambda hash_obj: _hash_compile_options(hash_obj, compile_options)), ) + entries.append( + ("devices", lambda hash_obj: _hash_devices(hash_obj, devices))) + entries.append( + ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)), + ) else: + assert ( + xla_extension_version >= 193 + ), "new cache key generation requires jaxlib 0.4.15 or newer" entries.append( ("compile_options", lambda hash_obj: _hash_serialized_compile_options( hash_obj, compile_options)), ) + entries.append( + ("accelerator_config", + lambda hash_obj: _hash_accelerator_config(hash_obj, devices)), + ) hash_obj = hashlib.sha256() for name, hashfn in entries: @@ -167,11 +177,16 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_serialized_compile_options(hash_obj, compile_options_obj): - assert ( - xla_extension_version >= 193 - ), "new cache key generation requires jaxlib 0.4.15 or newer" +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): + accelerator_devices = [] + for accelerator in accelerators.flat: + accelerator_devices.append(accelerator) + hash_obj.update( + xla_client.get_topology_for_devices(accelerator_devices).serialize() + ) + +def _hash_serialized_compile_options(hash_obj, compile_options_obj): # Do not mess with the original CompileOptions object since it is passed to # the compiler. Create a deep copy for the purpose of cache key generation. compile_options_copy = copy.deepcopy(compile_options_obj) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 1750648c5..21b2f2c99 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -265,10 +265,16 @@ def compile_or_get_cached( monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - cache_key = compilation_cache.get_cache_key( - computation, devices, compile_options, backend, - jax_config.config.jax_use_original_compilation_cache_key_generation, - ) + try: + cache_key = compilation_cache.get_cache_key( + computation, devices, compile_options, backend, + jax_config.config.jax_use_original_compilation_cache_key_generation, + ) + except xc._xla.XlaRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return backend_compile(backend, computation, compile_options, + host_callbacks) cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index bcb170bea..a3f904b27 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -131,6 +131,27 @@ class CacheKeyTest(jtu.JaxTestCase): ) self.assertEqual(hash1, hash2) + @unittest.skipIf( + xla_extension_version < 193, "Test requires jaxlib 0.4.15 or newer" + ) + @jtu.skip_on_devices("cpu") + def test_hash_accelerator_devices(self): + if jtu.is_se_tpu(): + raise unittest.SkipTest("StreamExecutor not supported.") + if xla_bridge.using_pjrt_c_api(): + # TODO(b/290248051): expose PjRtTopologyDesc in PjRt C API. + raise unittest.SkipTest("PjRt C API not yet supported.") + + devices = np.array([[jax.local_devices()[0]]]) + + dev_hash1 = self.get_hashed_value(cache_key._hash_devices, devices) + dev_hash2 = self.get_hashed_value(cache_key._hash_devices, devices) + self.assertEqual(dev_hash1, dev_hash2) + + acc_hash1 = self.get_hashed_value(cache_key._hash_accelerator_config, devices) + acc_hash2 = self.get_hashed_value(cache_key._hash_accelerator_config, devices) + self.assertEqual(acc_hash1, acc_hash2) + def test_hash_platform(self): hash1 = self.get_hashed_value( cache_key._hash_platform, xla_bridge.get_backend()