Enhance compilation cache key generation with a custom hook.

The custom hook is called every time the cache key is
generated. It can be programmed to add a custom string that
is then hashed as part of the cache key.

Testing: test workloads.
PiperOrigin-RevId: 610586945
This commit is contained in:
jax authors 2024-02-26 18:17:52 -08:00
parent fab8f6cfdd
commit 596756f715
2 changed files with 25 additions and 0 deletions

View File

@ -51,6 +51,15 @@ def get_flag_prefixes() -> list[str]:
return _extra_flag_prefixes
def custom_hook() -> str:
"""Custom hook for any addition to the cache key.
The custom hook will be called everytime get() is called and can be
defined to return a string that will be hashed into the cache key.
"""
return ""
def get(module: ir.Module,
devices: np.ndarray,
compile_options: xla_client.CompileOptions,
@ -86,6 +95,7 @@ def get(module: ir.Module,
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
("compression",
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
]
hash_obj = hashlib.sha256()

View File

@ -127,6 +127,21 @@ class CacheKeyTest(jtu.JaxTestCase):
cache_key.get(computation, devices, compile_options_filled, backend),
)
def test_custom_hook(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
original_custom_hook = cache_key.custom_hook
cache_key.custom_hook = lambda: "hook1"
key1 = cache_key.get(computation, devices, compile_options, backend)
cache_key.custom_hook = lambda: "hook2"
key2 = cache_key.get(computation, devices, compile_options, backend)
cache_key.custom_hook = original_custom_hook
self.assertNotEqual(key1, key2)
def test_different_computations(self):
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()