From 596756f71591f820311419f1e663f6572c2139f8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 26 Feb 2024 18:17:52 -0800 Subject: [PATCH] Enhance compilation cache key generation with a custom hook. The custom hook is called every time the cache key is generated. It can be programmed to add a custom string that is then hashed as part of the cache key. Testing: test workloads. PiperOrigin-RevId: 610586945 --- jax/_src/cache_key.py | 10 ++++++++++ tests/cache_key_test.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 43bb0fb06..13f3e2933 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -51,6 +51,15 @@ def get_flag_prefixes() -> list[str]: return _extra_flag_prefixes +def custom_hook() -> str: + """Custom hook for any addition to the cache key. + + The custom hook will be called everytime get() is called and can be + defined to return a string that will be hashed into the cache key. + """ + return "" + + def get(module: ir.Module, devices: np.ndarray, compile_options: xla_client.CompileOptions, @@ -86,6 +95,7 @@ def get(module: ir.Module, lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)), ("compression", lambda hash_obj: _hash_string(hash_obj, compression_algorithm)), + ("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())), ] hash_obj = hashlib.sha256() diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 8d04aeceb..33fc02009 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -127,6 +127,21 @@ class CacheKeyTest(jtu.JaxTestCase): cache_key.get(computation, devices, compile_options_filled, backend), ) + def test_custom_hook(self): + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + original_custom_hook = cache_key.custom_hook + cache_key.custom_hook = lambda: "hook1" + key1 = cache_key.get(computation, devices, compile_options, backend) + cache_key.custom_hook = lambda: "hook2" + key2 = cache_key.get(computation, devices, compile_options, backend) + cache_key.custom_hook = original_custom_hook + self.assertNotEqual(key1, key2) + def test_different_computations(self): computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()