From f939048e629f53810ec7ba7e0854a4c96177776f Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 4 Oct 2021 17:22:14 -0700 Subject: [PATCH] Include XLA_FLAGS in persistent compilation cache key. This is to prevent false cache hits when the compiler behavior is changed via flags. Flags known to not affect the compiled executable (e.g. dumping HLO) are excluded from the key. Note that any XLA flags with arguments should use = and not a space, e.g. `--xla_flag=value`, not `--xla_flag value`. I believe this is already a requirement of ABSL flags in general, but I'm not 100% sure. Also note that this doesn't currently support XLA flags specified via --flagfile. Please file a feature request if this is needed. --- .../compilation_cache/compilation_cache.py | 62 ++++++++++++++++--- tests/compilation_cache_test.py | 41 ++++++++++++ 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 77a423f20..e2650bf85 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -13,7 +13,9 @@ # limitations under the License. import hashlib +import os import re +import sys import jax from jax.experimental.compilation_cache.file_system_cache import FileSystemCache @@ -57,6 +59,11 @@ def put_executable(xla_computation, compile_options, executable: xla_client.Exec serialized_executable = backend.serialize_executable(executable) _cache.put(cache_key, serialized_executable) +def _log_cache_key_hash(hash_obj, last_serialized: str): + if logging.vlog_is_on(1): + logging.vlog(1, "get_cache_key hash after serializing %s: %s", + last_serialized, hash_obj.digest().hex()) + def get_cache_key(xla_computation, compile_options, backend) -> str: """Creates a hashed string to use as a key to the compilation cache. @@ -79,17 +86,20 @@ def get_cache_key(xla_computation, compile_options, backend) -> str: serialized_hlo = xla_computation.as_serialized_hlo_module_proto() scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo) hash_obj.update(scrubbed_hlo) - if logging.vlog_is_on(1): - logging.vlog(1, f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}") + _log_cache_key_hash(hash_obj, "computation") + _hash_compile_options(hash_obj, compile_options) - if logging.vlog_is_on(1): - logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}") + _log_cache_key_hash(hash_obj, "compile_options") + hash_obj.update(bytes(jax._src.lib.version)) - if logging.vlog_is_on(1): - logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}") + _log_cache_key_hash(hash_obj, "jax_lib version") + _hash_platform(hash_obj, backend) - if logging.vlog_is_on(1): - logging.vlog(1, f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}") + _log_cache_key_hash(hash_obj, "the backend") + + _hash_xla_flags(hash_obj) + _log_cache_key_hash(hash_obj, "XLA flags") + return hash_obj.digest().hex() def _hash_compile_options(hash_obj, compile_options_obj): @@ -145,6 +155,42 @@ def _hash_platform(hash_obj, backend): _hash_string(hash_obj, backend.platform_version) _hash_string(hash_obj, backend.runtime_type) +_xla_flags_to_exclude_from_cache_key = [ + "--xla_dump_compress_protos", + "--xla_dump_module_metadata", + "--xla_dump_max_hlo_modules", + "--xla_dump_include_timestamp", + "--xla_dump_hlo_pass_re", + "--xla_dump_hlo_module_re", + "--xla_dump_hlo_snapshots", + "--xla_dump_fusion_visualization", + "--xla_dump_hlo_as_url", + "--xla_dump_hlo_as_proto", + "--xla_dump_hlo_as_text", + "--xla_dump_to", + "--xla_force_host_platform_device_count", + "--xla_dump_disable_metadata", + "--xla_dump_hlo_pipeline_re", +] + +def _hash_xla_flags(hash_obj): + xla_flags = [] + + xla_flags_env_var = os.getenv("XLA_FLAGS") + if xla_flags_env_var: + xla_flags.extend(xla_flags_env_var.split()) + + xla_flags.extend(arg for arg in sys.argv if arg.startswith("--xla_")) + + # N.B. all XLA flags that take an argument must use '=' and not a space + # (e.g. --xla_force_host_platform_device_count=8) (I think). + for flag in xla_flags: + if flag.split('=')[0] in _xla_flags_to_exclude_from_cache_key: + logging.vlog(1, "Not including XLA flag in cache key: %s", flag) + continue + logging.vlog(1, "Including XLA flag in cache key: %s", flag) + _hash_string(hash_obj, flag) + def _hash_int(hash_obj, int_var): hash_obj.update(int_var.to_bytes(8, byteorder='big')) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 3c98240b5..44e7d98c5 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -16,6 +16,7 @@ from functools import partial import hashlib import os import random +import sys import tempfile import unittest from unittest import SkipTest @@ -141,6 +142,46 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend), cc.get_cache_key(computation2, compile_options, backend)) + def test_xla_flags(self): + computation = jax.xla_computation(lambda x, y: x + y)(1, 1) + compile_options = jax._src.lib.xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1) + backend = jax._src.lib.xla_bridge.get_backend() + + orig_xla_flags = os.getenv("XLA_FLAGS") + orig_argv = sys.argv + try: + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" + key1 = cc.get_cache_key(computation, compile_options, backend) + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1" + key2 = cc.get_cache_key(computation, compile_options, backend) + self.assertNotEqual(key1, key2) + + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" + key3 = cc.get_cache_key(computation, compile_options, backend) + self.assertEqual(key1, key3) + + # Test flag in _xla_flags_to_exclude_from_cache_key + os.environ["XLA_FLAGS"] = ( + "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8") + key4 = cc.get_cache_key(computation, compile_options, backend) + self.assertEqual(key1, key4) + + # Test flags given on command line + del os.environ["XLA_FLAGS"] + sys.argv.append("--xla_gpu_autotune_level=0") + key5 = cc.get_cache_key(computation, compile_options, backend) + self.assertEqual(key1, key5) + sys.argv.append("--xla_force_host_platform_device_count=8") + self.assertEqual(key1, key5) + + finally: + if orig_xla_flags is not None: + os.environ["XLA_FLAGS"] = orig_xla_flags + elif os.getenv("XLA_FLAGS") is not None: + del os.environ["XLA_FLAGS"] + sys.argv = orig_argv + def test_get_no_executable(self): with tempfile.TemporaryDirectory() as tmpdir: cc.initialize_cache(tmpdir)