mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This is not future proof: it does not work for accelerators other than TPUs. Change this to use the serialized version of PjRtTopologyDescription which is supported for all accelerators. Note: . CPU and PjRt C API not supported as yet. . Stream Executor will not be supported. Testing: revised unit test. PiperOrigin-RevId: 564461564
This commit is contained in:
parent
a36598b2a7
commit
23e4f0b471
@ -86,11 +86,9 @@ def get(module: ir.Module,
|
||||
"""
|
||||
entries = [
|
||||
("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
|
||||
("devices", lambda hash_obj: _hash_devices(hash_obj, devices)),
|
||||
("jax_lib version",
|
||||
lambda hash_obj: hash_obj.update(
|
||||
bytes(jaxlib_version_str.encode("utf-8")))),
|
||||
("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
|
||||
("XLA flags",
|
||||
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
|
||||
("compression",
|
||||
@ -101,12 +99,24 @@ def get(module: ir.Module,
|
||||
("compile_options",
|
||||
lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
|
||||
)
|
||||
entries.append(
|
||||
("devices", lambda hash_obj: _hash_devices(hash_obj, devices)))
|
||||
entries.append(
|
||||
("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
xla_extension_version >= 193
|
||||
), "new cache key generation requires jaxlib 0.4.15 or newer"
|
||||
entries.append(
|
||||
("compile_options",
|
||||
lambda hash_obj: _hash_serialized_compile_options(
|
||||
hash_obj, compile_options)),
|
||||
)
|
||||
entries.append(
|
||||
("accelerator_config",
|
||||
lambda hash_obj: _hash_accelerator_config(hash_obj, devices)),
|
||||
)
|
||||
|
||||
hash_obj = hashlib.sha256()
|
||||
for name, hashfn in entries:
|
||||
@ -167,11 +177,16 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None:
|
||||
_hash_string(hash_obj, device.device_kind)
|
||||
|
||||
|
||||
def _hash_serialized_compile_options(hash_obj, compile_options_obj):
|
||||
assert (
|
||||
xla_extension_version >= 193
|
||||
), "new cache key generation requires jaxlib 0.4.15 or newer"
|
||||
def _hash_accelerator_config(hash_obj, accelerators: np.ndarray):
|
||||
accelerator_devices = []
|
||||
for accelerator in accelerators.flat:
|
||||
accelerator_devices.append(accelerator)
|
||||
hash_obj.update(
|
||||
xla_client.get_topology_for_devices(accelerator_devices).serialize()
|
||||
)
|
||||
|
||||
|
||||
def _hash_serialized_compile_options(hash_obj, compile_options_obj):
|
||||
# Do not mess with the original CompileOptions object since it is passed to
|
||||
# the compiler. Create a deep copy for the purpose of cache key generation.
|
||||
compile_options_copy = copy.deepcopy(compile_options_obj)
|
||||
|
@ -265,10 +265,16 @@ def compile_or_get_cached(
|
||||
|
||||
monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')
|
||||
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend,
|
||||
jax_config.config.jax_use_original_compilation_cache_key_generation,
|
||||
)
|
||||
try:
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend,
|
||||
jax_config.config.jax_use_original_compilation_cache_key_generation,
|
||||
)
|
||||
except xc._xla.XlaRuntimeError as ex:
|
||||
logger.error("compile_or_get_cached: unable to generate cache key, "
|
||||
"skipping the cache: %s", ex)
|
||||
return backend_compile(backend, computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
cache_retrieval_start = time.monotonic()
|
||||
retrieved_executable, retrieved_compile_time = _cache_read(
|
||||
|
@ -131,6 +131,27 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
)
|
||||
self.assertEqual(hash1, hash2)
|
||||
|
||||
@unittest.skipIf(
|
||||
xla_extension_version < 193, "Test requires jaxlib 0.4.15 or newer"
|
||||
)
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_hash_accelerator_devices(self):
|
||||
if jtu.is_se_tpu():
|
||||
raise unittest.SkipTest("StreamExecutor not supported.")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
# TODO(b/290248051): expose PjRtTopologyDesc in PjRt C API.
|
||||
raise unittest.SkipTest("PjRt C API not yet supported.")
|
||||
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
|
||||
dev_hash1 = self.get_hashed_value(cache_key._hash_devices, devices)
|
||||
dev_hash2 = self.get_hashed_value(cache_key._hash_devices, devices)
|
||||
self.assertEqual(dev_hash1, dev_hash2)
|
||||
|
||||
acc_hash1 = self.get_hashed_value(cache_key._hash_accelerator_config, devices)
|
||||
acc_hash2 = self.get_hashed_value(cache_key._hash_accelerator_config, devices)
|
||||
self.assertEqual(acc_hash1, acc_hash2)
|
||||
|
||||
def test_hash_platform(self):
|
||||
hash1 = self.get_hashed_value(
|
||||
cache_key._hash_platform, xla_bridge.get_backend()
|
||||
|
Loading…
x
Reference in New Issue
Block a user