Merge pull request #8090 from skye:compilation_cache_xla_flags

PiperOrigin-RevId: 401343120
This commit is contained in:
jax authors 2021-10-06 14:42:18 -07:00
commit 3c117fd6ed
2 changed files with 95 additions and 8 deletions

View File

@ -13,7 +13,9 @@
# limitations under the License.
import hashlib
import os
import re
import sys
import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
@ -57,6 +59,11 @@ def put_executable(xla_computation, compile_options, executable: xla_client.Exec
serialized_executable = backend.serialize_executable(executable)
_cache.put(cache_key, serialized_executable)
def _log_cache_key_hash(hash_obj, last_serialized: str):
  """Logs the current digest of `hash_obj` at verbosity level 1.

  Nothing is computed or logged unless vlog level 1 is enabled.

  Args:
    hash_obj: hash object whose `digest()` is rendered as hex in the message.
    last_serialized: label interpolated into the message to say what was most
      recently serialized into the hash.
  """
  if not logging.vlog_is_on(1):
    return
  digest_hex = hash_obj.digest().hex()
  logging.vlog(1, "get_cache_key hash after serializing %s: %s",
               last_serialized, digest_hex)
def get_cache_key(xla_computation, compile_options, backend) -> str:
"""Creates a hashed string to use as a key to the compilation cache.
@ -79,17 +86,20 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
hash_obj.update(scrubbed_hlo)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "computation")
_hash_compile_options(hash_obj, compile_options)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "compile_options")
hash_obj.update(bytes(jax._src.lib.version))
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "jax_lib version")
_hash_platform(hash_obj, backend)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "the backend")
_hash_xla_flags(hash_obj)
_log_cache_key_hash(hash_obj, "XLA flags")
return hash_obj.digest().hex()
def _hash_compile_options(hash_obj, compile_options_obj):
@ -145,6 +155,42 @@ def _hash_platform(hash_obj, backend):
_hash_string(hash_obj, backend.platform_version)
_hash_string(hash_obj, backend.runtime_type)
# XLA flags that _hash_xla_flags leaves out of the compilation cache key. A
# flag is matched on the text before its first '=', so whatever value is
# passed to one of these flags is ignored as well.
# NOTE(review): presumably these are excluded because they only control
# dumping/debug output or the host device count rather than the generated
# code -- confirm before adding new entries.
_xla_flags_to_exclude_from_cache_key = [
"--xla_dump_compress_protos",
"--xla_dump_module_metadata",
"--xla_dump_max_hlo_modules",
"--xla_dump_include_timestamp",
"--xla_dump_hlo_pass_re",
"--xla_dump_hlo_module_re",
"--xla_dump_hlo_snapshots",
"--xla_dump_fusion_visualization",
"--xla_dump_hlo_as_url",
"--xla_dump_hlo_as_proto",
"--xla_dump_hlo_as_text",
"--xla_dump_to",
"--xla_force_host_platform_device_count",
"--xla_dump_disable_metadata",
"--xla_dump_hlo_pipeline_re",
]
def _hash_xla_flags(hash_obj):
  """Hashes the XLA flags in effect for this process into `hash_obj`.

  Flags are gathered from the whitespace-separated `XLA_FLAGS` environment
  variable first, followed by every `sys.argv` entry that starts with
  "--xla_". Each flag is hashed in that order via `_hash_string`, except for
  flags whose name (the text before the first '=') is listed in
  `_xla_flags_to_exclude_from_cache_key`.

  Args:
    hash_obj: hash object that the retained flag strings are fed into.
  """
  env_value = os.getenv("XLA_FLAGS")
  candidates = env_value.split() if env_value else []
  candidates += [arg for arg in sys.argv if arg.startswith("--xla_")]

  # NOTE: the exclusion check below relies on flags that take a value being
  # written as --flag=value (e.g. --xla_force_host_platform_device_count=8)
  # rather than "--flag value". That this holds for every XLA flag is an
  # unverified assumption.
  for candidate in candidates:
    flag_name = candidate.split('=')[0]
    if flag_name not in _xla_flags_to_exclude_from_cache_key:
      logging.vlog(1, "Including XLA flag in cache key: %s", candidate)
      _hash_string(hash_obj, candidate)
    else:
      logging.vlog(1, "Not including XLA flag in cache key: %s", candidate)
def _hash_int(hash_obj, int_var):
  """Feeds `int_var` into `hash_obj` as a fixed-width 8-byte value.

  The integer is encoded as unsigned big-endian, so `int.to_bytes` raises
  OverflowError for negative values or values that do not fit in 8 bytes.

  Args:
    hash_obj: hash object exposing `update(bytes)`.
    int_var: integer to encode and hash.
  """
  encoded = int_var.to_bytes(8, byteorder='big')
  hash_obj.update(encoded)

View File

@ -16,6 +16,7 @@ from functools import partial
import hashlib
import os
import random
import sys
import tempfile
import unittest
from unittest import SkipTest
@ -141,6 +142,46 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
cc.get_cache_key(computation2, compile_options, backend))
def test_xla_flags(self):
  """Checks how XLA flags influence the compilation cache key.

  Covers flags supplied through the XLA_FLAGS environment variable and
  through the command line (sys.argv): changing a regular flag's value must
  change the key, while flags listed in
  cc._xla_flags_to_exclude_from_cache_key must leave it untouched.
  """
  computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
  compile_options = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  backend = jax._src.lib.xla_bridge.get_backend()

  orig_xla_flags = os.getenv("XLA_FLAGS")
  # Snapshot a *copy*: the appends below mutate sys.argv in place, so keeping
  # only a reference to the same list would make the restore a no-op.
  orig_argv = list(sys.argv)
  try:
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key1 = cc.get_cache_key(computation, compile_options, backend)
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
    key2 = cc.get_cache_key(computation, compile_options, backend)
    self.assertNotEqual(key1, key2)

    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key3 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key3)

    # Test flag in _xla_flags_to_exclude_from_cache_key
    os.environ["XLA_FLAGS"] = (
        "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8")
    key4 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key4)

    # Test flags given on command line
    del os.environ["XLA_FLAGS"]
    sys.argv.append("--xla_gpu_autotune_level=0")
    key5 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key5)
    sys.argv.append("--xla_force_host_platform_device_count=8")
    # Recompute the key so the excluded command-line flag is really exercised.
    key6 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key6)
  finally:
    if orig_xla_flags is not None:
      os.environ["XLA_FLAGS"] = orig_xla_flags
    else:
      os.environ.pop("XLA_FLAGS", None)
    # Restore in place so every holder of the sys.argv list sees the original.
    sys.argv[:] = orig_argv
def test_get_no_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)