mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #8090 from skye:compilation_cache_xla_flags
PiperOrigin-RevId: 401343120
This commit is contained in:
commit
3c117fd6ed
@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import jax
|
||||
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
|
||||
@ -57,6 +59,11 @@ def put_executable(xla_computation, compile_options, executable: xla_client.Exec
|
||||
serialized_executable = backend.serialize_executable(executable)
|
||||
_cache.put(cache_key, serialized_executable)
|
||||
|
||||
def _log_cache_key_hash(hash_obj, last_serialized: str):
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, "get_cache_key hash after serializing %s: %s",
|
||||
last_serialized, hash_obj.digest().hex())
|
||||
|
||||
def get_cache_key(xla_computation, compile_options, backend) -> str:
|
||||
"""Creates a hashed string to use as a key to the compilation cache.
|
||||
|
||||
@ -79,17 +86,20 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
|
||||
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
|
||||
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
|
||||
hash_obj.update(scrubbed_hlo)
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}")
|
||||
_log_cache_key_hash(hash_obj, "computation")
|
||||
|
||||
_hash_compile_options(hash_obj, compile_options)
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}")
|
||||
_log_cache_key_hash(hash_obj, "compile_options")
|
||||
|
||||
hash_obj.update(bytes(jax._src.lib.version))
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
|
||||
_log_cache_key_hash(hash_obj, "jax_lib version")
|
||||
|
||||
_hash_platform(hash_obj, backend)
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}")
|
||||
_log_cache_key_hash(hash_obj, "the backend")
|
||||
|
||||
_hash_xla_flags(hash_obj)
|
||||
_log_cache_key_hash(hash_obj, "XLA flags")
|
||||
|
||||
return hash_obj.digest().hex()
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
@ -145,6 +155,42 @@ def _hash_platform(hash_obj, backend):
|
||||
_hash_string(hash_obj, backend.platform_version)
|
||||
_hash_string(hash_obj, backend.runtime_type)
|
||||
|
||||
_xla_flags_to_exclude_from_cache_key = [
|
||||
"--xla_dump_compress_protos",
|
||||
"--xla_dump_module_metadata",
|
||||
"--xla_dump_max_hlo_modules",
|
||||
"--xla_dump_include_timestamp",
|
||||
"--xla_dump_hlo_pass_re",
|
||||
"--xla_dump_hlo_module_re",
|
||||
"--xla_dump_hlo_snapshots",
|
||||
"--xla_dump_fusion_visualization",
|
||||
"--xla_dump_hlo_as_url",
|
||||
"--xla_dump_hlo_as_proto",
|
||||
"--xla_dump_hlo_as_text",
|
||||
"--xla_dump_to",
|
||||
"--xla_force_host_platform_device_count",
|
||||
"--xla_dump_disable_metadata",
|
||||
"--xla_dump_hlo_pipeline_re",
|
||||
]
|
||||
|
||||
def _hash_xla_flags(hash_obj):
|
||||
xla_flags = []
|
||||
|
||||
xla_flags_env_var = os.getenv("XLA_FLAGS")
|
||||
if xla_flags_env_var:
|
||||
xla_flags.extend(xla_flags_env_var.split())
|
||||
|
||||
xla_flags.extend(arg for arg in sys.argv if arg.startswith("--xla_"))
|
||||
|
||||
# N.B. all XLA flags that take an argument must use '=' and not a space
|
||||
# (e.g. --xla_force_host_platform_device_count=8) (I think).
|
||||
for flag in xla_flags:
|
||||
if flag.split('=')[0] in _xla_flags_to_exclude_from_cache_key:
|
||||
logging.vlog(1, "Not including XLA flag in cache key: %s", flag)
|
||||
continue
|
||||
logging.vlog(1, "Including XLA flag in cache key: %s", flag)
|
||||
_hash_string(hash_obj, flag)
|
||||
|
||||
def _hash_int(hash_obj, int_var):
|
||||
hash_obj.update(int_var.to_bytes(8, byteorder='big'))
|
||||
|
||||
|
@ -16,6 +16,7 @@ from functools import partial
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
@ -141,6 +142,46 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
|
||||
cc.get_cache_key(computation2, compile_options, backend))
|
||||
|
||||
def test_xla_flags(self):
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
|
||||
orig_xla_flags = os.getenv("XLA_FLAGS")
|
||||
orig_argv = sys.argv
|
||||
try:
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
|
||||
key1 = cc.get_cache_key(computation, compile_options, backend)
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
|
||||
key2 = cc.get_cache_key(computation, compile_options, backend)
|
||||
self.assertNotEqual(key1, key2)
|
||||
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
|
||||
key3 = cc.get_cache_key(computation, compile_options, backend)
|
||||
self.assertEqual(key1, key3)
|
||||
|
||||
# Test flag in _xla_flags_to_exclude_from_cache_key
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
"--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8")
|
||||
key4 = cc.get_cache_key(computation, compile_options, backend)
|
||||
self.assertEqual(key1, key4)
|
||||
|
||||
# Test flags given on command line
|
||||
del os.environ["XLA_FLAGS"]
|
||||
sys.argv.append("--xla_gpu_autotune_level=0")
|
||||
key5 = cc.get_cache_key(computation, compile_options, backend)
|
||||
self.assertEqual(key1, key5)
|
||||
sys.argv.append("--xla_force_host_platform_device_count=8")
|
||||
self.assertEqual(key1, key5)
|
||||
|
||||
finally:
|
||||
if orig_xla_flags is not None:
|
||||
os.environ["XLA_FLAGS"] = orig_xla_flags
|
||||
elif os.getenv("XLA_FLAGS") is not None:
|
||||
del os.environ["XLA_FLAGS"]
|
||||
sys.argv = orig_argv
|
||||
|
||||
def test_get_no_executable(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user