diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 43bb0fb06..13f3e2933 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -51,6 +51,15 @@ def get_flag_prefixes() -> list[str]: return _extra_flag_prefixes +def custom_hook() -> str: + """Custom hook for any addition to the cache key. + + The custom hook will be called everytime get() is called and can be + defined to return a string that will be hashed into the cache key. + """ + return "" + + def get(module: ir.Module, devices: np.ndarray, compile_options: xla_client.CompileOptions, @@ -86,6 +95,7 @@ def get(module: ir.Module, lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)), ("compression", lambda hash_obj: _hash_string(hash_obj, compression_algorithm)), + ("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())), ] hash_obj = hashlib.sha256() diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 8d04aeceb..33fc02009 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -127,6 +127,21 @@ class CacheKeyTest(jtu.JaxTestCase): cache_key.get(computation, devices, compile_options_filled, backend), ) + def test_custom_hook(self): + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + original_custom_hook = cache_key.custom_hook + cache_key.custom_hook = lambda: "hook1" + key1 = cache_key.get(computation, devices, compile_options, backend) + cache_key.custom_hook = lambda: "hook2" + key2 = cache_key.get(computation, devices, compile_options, backend) + cache_key.custom_hook = original_custom_hook + self.assertNotEqual(key1, key2) + def test_different_computations(self): computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()