mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Enhance compilation cache key generation with a custom hook.
The custom hook is called every time the cache key is generated. It can be programmed to add a custom string that is then hashed as part of the cache key. Testing: test workloads. PiperOrigin-RevId: 610586945
This commit is contained in:
parent
fab8f6cfdd
commit
596756f715
@ -51,6 +51,15 @@ def get_flag_prefixes() -> list[str]:
|
||||
return _extra_flag_prefixes
|
||||
|
||||
|
||||
def custom_hook() -> str:
|
||||
"""Custom hook for any addition to the cache key.
|
||||
|
||||
The custom hook will be called everytime get() is called and can be
|
||||
defined to return a string that will be hashed into the cache key.
|
||||
"""
|
||||
return ""
|
||||
|
||||
|
||||
def get(module: ir.Module,
|
||||
devices: np.ndarray,
|
||||
compile_options: xla_client.CompileOptions,
|
||||
@ -86,6 +95,7 @@ def get(module: ir.Module,
|
||||
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
|
||||
("compression",
|
||||
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
|
||||
("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
|
||||
]
|
||||
|
||||
hash_obj = hashlib.sha256()
|
||||
|
@ -127,6 +127,21 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
cache_key.get(computation, devices, compile_options_filled, backend),
|
||||
)
|
||||
|
||||
def test_custom_hook(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
original_custom_hook = cache_key.custom_hook
|
||||
cache_key.custom_hook = lambda: "hook1"
|
||||
key1 = cache_key.get(computation, devices, compile_options, backend)
|
||||
cache_key.custom_hook = lambda: "hook2"
|
||||
key2 = cache_key.get(computation, devices, compile_options, backend)
|
||||
cache_key.custom_hook = original_custom_hook
|
||||
self.assertNotEqual(key1, key2)
|
||||
|
||||
def test_different_computations(self):
|
||||
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
|
||||
|
Loading…
x
Reference in New Issue
Block a user