Hash serialized topology description for new cache key generation.

The original cache key generation hashes the devices and the backend. This
is not future-proof: it does not work for accelerators other than
TPUs. Change it to hash the serialized form of
PjRtTopologyDescription instead, which is supported for all accelerators.

Note:
- CPU and the PjRt C API are not yet supported.
- StreamExecutor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
jax authors 2023-09-11 12:07:48 -07:00
parent a36598b2a7
commit 23e4f0b471
3 changed files with 52 additions and 10 deletions
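As context for the diffs below, here is a minimal sketch of the idea, not the code in this commit; topology_fingerprint is a hypothetical helper and the import path is an assumption.

# Hedged sketch only; requires a backend with topology support
# (per the note above, not CPU and not the PjRt C API yet).
import hashlib

import numpy as np

import jax
from jax._src.lib import xla_client


def topology_fingerprint(devices: np.ndarray) -> str:
  # Serialize the topology description covering the given devices and hash it.
  topology = xla_client.get_topology_for_devices(list(devices.flat))
  digest = hashlib.sha256()
  digest.update(topology.serialize())
  return digest.hexdigest()


if __name__ == "__main__":
  print(topology_fingerprint(np.array([jax.local_devices()])))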


@@ -86,11 +86,9 @@ def get(module: ir.Module,
   """
   entries = [
       ("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
-      ("devices", lambda hash_obj: _hash_devices(hash_obj, devices)),
       ("jax_lib version",
        lambda hash_obj: hash_obj.update(
            bytes(jaxlib_version_str.encode("utf-8")))),
-      ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
       ("XLA flags",
        lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
       ("compression",
@@ -101,12 +99,24 @@
         ("compile_options",
          lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
     )
+    entries.append(
+        ("devices", lambda hash_obj: _hash_devices(hash_obj, devices)))
+    entries.append(
+        ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
+    )
   else:
+    assert (
+        xla_extension_version >= 193
+    ), "new cache key generation requires jaxlib 0.4.15 or newer"
     entries.append(
         ("compile_options",
          lambda hash_obj: _hash_serialized_compile_options(
              hash_obj, compile_options)),
     )
+    entries.append(
+        ("accelerator_config",
+         lambda hash_obj: _hash_accelerator_config(hash_obj, devices)),
+    )

   hash_obj = hashlib.sha256()
   for name, hashfn in entries:
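The hunk above shows only the branch bodies of get(). The toy sketch below (hypothetical names, standard library only, not the jax implementation) illustrates the pattern: a list of named hashing callbacks is assembled, with the flag choosing between the legacy device/backend entries and the new serialized entries, and every callback is folded into one SHA-256 digest.

# Toy illustration of the entry-list pattern; toy_cache_key is hypothetical.
import hashlib


def toy_cache_key(computation: str, device_kinds: list, topology_bytes: bytes,
                  use_original: bool) -> str:
  entries = [("computation", lambda h: h.update(computation.encode()))]
  if use_original:
    # Legacy behaviour: hash the device kinds (and, in jax, the backend too).
    entries.append(("devices", lambda h: h.update(",".join(device_kinds).encode())))
  else:
    # New behaviour: hash the serialized topology description instead.
    entries.append(("accelerator_config", lambda h: h.update(topology_bytes)))
  hash_obj = hashlib.sha256()
  for _name, hashfn in entries:
    hashfn(hash_obj)
  return hash_obj.hexdigest()


# The two modes intentionally produce different keys for the same computation.
assert toy_cache_key("f", ["tpu"], b"topo", True) != toy_cache_key("f", ["tpu"], b"topo", False)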
@@ -167,11 +177,16 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None:
     _hash_string(hash_obj, device.device_kind)


-def _hash_serialized_compile_options(hash_obj, compile_options_obj):
-  assert (
-      xla_extension_version >= 193
-  ), "new cache key generation requires jaxlib 0.4.15 or newer"
-
+def _hash_accelerator_config(hash_obj, accelerators: np.ndarray):
+  accelerator_devices = []
+  for accelerator in accelerators.flat:
+    accelerator_devices.append(accelerator)
+  hash_obj.update(
+      xla_client.get_topology_for_devices(accelerator_devices).serialize()
+  )
+
+
+def _hash_serialized_compile_options(hash_obj, compile_options_obj):
   # Do not mess with the original CompileOptions object since it is passed to
   # the compiler. Create a deep copy for the purpose of cache key generation.
   compile_options_copy = copy.deepcopy(compile_options_obj)
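Assuming the helper lives in jax/_src/cache_key.py, as the imports in the revised test further below suggest, it can be exercised directly; this usage snippet is illustrative only and needs a backend with topology support (not CPU, not the PjRt C API yet).

# Illustrative call into the private helper added above; import path assumed.
import hashlib

import numpy as np

import jax
from jax._src import cache_key

hash_obj = hashlib.sha256()
cache_key._hash_accelerator_config(hash_obj, np.array([[jax.local_devices()[0]]]))
print(hash_obj.hexdigest())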


@@ -265,10 +265,16 @@ def compile_or_get_cached(
   monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')
-  cache_key = compilation_cache.get_cache_key(
-      computation, devices, compile_options, backend,
-      jax_config.config.jax_use_original_compilation_cache_key_generation,
-  )
+  try:
+    cache_key = compilation_cache.get_cache_key(
+        computation, devices, compile_options, backend,
+        jax_config.config.jax_use_original_compilation_cache_key_generation,
+    )
+  except xc._xla.XlaRuntimeError as ex:
+    logger.error("compile_or_get_cached: unable to generate cache key, "
+                 "skipping the cache: %s", ex)
+    return backend_compile(backend, computation, compile_options,
+                           host_callbacks)

   cache_retrieval_start = time.monotonic()
   retrieved_executable, retrieved_compile_time = _cache_read(
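The try/except above makes cache-key generation non-fatal: if the runtime cannot produce a topology (for example on a configuration listed in the commit note), compilation proceeds without the cache. A generic sketch of that fallback pattern, with placeholder callables rather than jax APIs:

# Generic fallback pattern; compute_key, lookup_cache and compile_now are placeholders.
import logging

logger = logging.getLogger(__name__)


def compile_with_optional_cache(compute_key, lookup_cache, compile_now):
  try:
    key = compute_key()
  except RuntimeError as ex:  # stand-in for xc._xla.XlaRuntimeError
    logger.error("unable to generate cache key, skipping the cache: %s", ex)
    return compile_now()
  return lookup_cache(key) or compile_now()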


@@ -131,6 +131,27 @@ class CacheKeyTest(jtu.JaxTestCase):
     )
     self.assertEqual(hash1, hash2)

+  @unittest.skipIf(
+      xla_extension_version < 193, "Test requires jaxlib 0.4.15 or newer"
+  )
+  @jtu.skip_on_devices("cpu")
+  def test_hash_accelerator_devices(self):
+    if jtu.is_se_tpu():
+      raise unittest.SkipTest("StreamExecutor not supported.")
+    if xla_bridge.using_pjrt_c_api():
+      # TODO(b/290248051): expose PjRtTopologyDesc in PjRt C API.
+      raise unittest.SkipTest("PjRt C API not yet supported.")
+
+    devices = np.array([[jax.local_devices()[0]]])
+
+    dev_hash1 = self.get_hashed_value(cache_key._hash_devices, devices)
+    dev_hash2 = self.get_hashed_value(cache_key._hash_devices, devices)
+    self.assertEqual(dev_hash1, dev_hash2)
+
+    acc_hash1 = self.get_hashed_value(cache_key._hash_accelerator_config, devices)
+    acc_hash2 = self.get_hashed_value(cache_key._hash_accelerator_config, devices)
+    self.assertEqual(acc_hash1, acc_hash2)
+
   def test_hash_platform(self):
     hash1 = self.get_hashed_value(
         cache_key._hash_platform, xla_bridge.get_backend()