Remove the old cache-key generation code.

We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices + platform if the
PjRtTopologyDescription serialization has not been implemented by a backend,
we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
jax authors 2023-12-12 16:33:45 -08:00
parent 05df8750ce
commit 32c99f627e
6 changed files with 23 additions and 305 deletions
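
For orientation before the diff: the retained code path composes the key from a
fixed list of named entries that all feed one SHA-256. Below is a minimal
standalone sketch of that idea only; it is not the actual cache_key.get()
implementation, and the byte-string payloads are placeholders for the hashed
MLIR module, jaxlib version, XLA flags, serialized compile options,
accelerator config, and compression algorithm.

import hashlib

def compute_key(entries):
    # One running SHA-256 accumulates every named entry's contribution.
    hash_obj = hashlib.sha256()
    for name, hash_fn in entries:
        hash_fn(hash_obj)
    return hash_obj.hexdigest()

entries = [
    # Placeholder payloads only; the real entries hash the serialized module,
    # the jaxlib version string, XLA flags, the serialized CompileOptions
    # proto, the accelerator config, and the compression algorithm.
    ("computation", lambda h: h.update(b"<serialized module>")),
    ("jaxlib version", lambda h: h.update(b"0.4.23")),
    ("XLA flags", lambda h: h.update(b"--example_flag=true")),
    ("compile_options", lambda h: h.update(b"<serialized CompileOptions>")),
    ("accelerator_config", lambda h: h.update(b"<device kinds / platform>")),
    ("compression", lambda h: h.update(b"zstandard")),
]
print(compute_key(entries))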


@@ -17,7 +17,6 @@ import hashlib
import io
import logging
import os
import struct
import sys
from jax._src import config
@@ -56,8 +55,7 @@ def get(module: ir.Module,
devices: np.ndarray,
compile_options: xla_client.CompileOptions,
backend: xla_client.Client,
compression_algorithm: str = "zstandard",
produce_original_cache_key: bool = True) -> str:
compression_algorithm: str = "zstandard") -> str:
"""Creates a hashed string to use as a key to the compilation cache.
Creates a cache key that is a hex-encoded string of a unique hash based on
@@ -70,10 +68,6 @@ def get(module: ir.Module,
backend: description of the platform (e.g., TPU version)
compression_algorithm: a string representing the compression algorithm used
for the executable before persisting in the cache
produce_original_cache_key: if True, the original cache-key generation
algorithm is run, else the new one. This is transient; once the migration
is complete, this parameter and the original algorithm will be removed.
(New one not implemented as yet.)
Typical return value example:
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
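
The typical return value above suggests the final key pairs the computation's
name with the hex-encoded digest. A minimal sketch of that format, as an
assumption drawn from the docstring example rather than from the library code:

import hashlib

module_name = "jit__psum"  # hypothetical computation name
digest = hashlib.sha256(b"<hashed cache-key entries>").hexdigest()
print(f"{module_name}-{digest}")  # jit__psum-<64 hex characters>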
@@ -85,29 +79,14 @@ def get(module: ir.Module,
bytes(jaxlib_version_str.encode("utf-8")))),
("XLA flags",
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
("compile_options",
lambda hash_obj: _hash_serialized_compile_options(
hash_obj, compile_options)),
("accelerator_config",
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
("compression",
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
]
if produce_original_cache_key:
entries.append(
("compile_options",
lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
)
entries.append(
("devices", lambda hash_obj: _hash_devices(hash_obj, devices)))
entries.append(
("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
)
else:
entries.append(
("compile_options",
lambda hash_obj: _hash_serialized_compile_options(
hash_obj, compile_options)),
)
entries.append(
("accelerator_config",
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
)
hash_obj = hashlib.sha256()
for name, hashfn in entries:
@@ -217,98 +196,6 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
return hash_obj.update(compile_options_copy.SerializeAsString())
def _hash_compile_options(hash_obj, compile_options_obj):
expected_num_compile_options = 12
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
# built-ins.
num_actual_options = len(
[x for x in dir(compile_options_obj) if not x.startswith("_")]
)
assert num_actual_options == expected_num_compile_options, (
"Unexpected number of CompileOption fields: "
f"{num_actual_options}. This likely: means that an extra "
"field was added, and this function needs to be updated."
)
if compile_options_obj.argument_layouts is not None:
map(
lambda shape: hash_obj.update(shape.to_serialized_proto()),
compile_options_obj.argument_layouts,
)
_hash_int(hash_obj, compile_options_obj.parameter_is_tupled_arguments)
_hash_executable_build_options(
hash_obj, compile_options_obj.executable_build_options
)
_hash_bool(hash_obj, compile_options_obj.tuple_arguments)
_hash_int(hash_obj, compile_options_obj.num_replicas)
_hash_int(hash_obj, compile_options_obj.num_partitions)
_hash_signed_int(hash_obj, compile_options_obj.profile_version)
if compile_options_obj.device_assignment is not None:
hash_obj.update(compile_options_obj.device_assignment.serialize())
_hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
_hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
for kv in compile_options_obj.env_option_overrides:
_hash_string(hash_obj, kv[0])
if isinstance(kv[1], str):
_hash_string(hash_obj, kv[1])
elif isinstance(kv[1], bool):
_hash_bool(hash_obj, kv[1])
elif isinstance(kv[1], int):
_hash_int(hash_obj, kv[1])
elif isinstance(kv[1], float):
_hash_float(hash_obj, kv[1])
else:
raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))
def _hash_executable_build_options(hash_obj, executable_obj):
expected_options = 11
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
# built-ins.
actual_options = len(
[x for x in dir(executable_obj) if not x.startswith("_")]
)
assert actual_options == expected_options, (
"Unexpected number of executable_build_options fields: "
f"{actual_options}, expected: {expected_options}. This likely means "
"that an extra field was added, and this function needs to be updated."
)
if executable_obj.result_layout is not None:
hash_obj.update(executable_obj.result_layout.to_serialized_proto())
_hash_int(hash_obj, executable_obj.num_replicas)
_hash_int(hash_obj, executable_obj.num_partitions)
_hash_debug_options(hash_obj, executable_obj.debug_options)
if executable_obj.device_assignment is not None:
hash_obj.update(executable_obj.device_assignment.serialize())
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
_hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning)
if executable_obj.use_auto_spmd_partitioning:
if executable_obj.auto_spmd_partitioning_mesh_shape is not None:
_hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_shape)
if executable_obj.auto_spmd_partitioning_mesh_ids is not None:
_hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_ids)
_hash_bool_list(
hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output
)
if executable_obj.fdo_profile is not None:
_hash_string(hash_obj, executable_obj.fdo_profile)
def _hash_debug_options(hash_obj, debug_obj):
_hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math)
_hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_infs)
_hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_nans)
_hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_division)
_hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_functions)
_hash_bool(hash_obj, debug_obj.xla_gpu_enable_fast_min_max)
_hash_int(hash_obj, debug_obj.xla_backend_optimization_level)
_hash_bool(hash_obj, debug_obj.xla_cpu_enable_xprof_traceme)
_hash_bool(hash_obj, debug_obj.xla_llvm_disable_expensive_passes)
_hash_bool(hash_obj, debug_obj.xla_test_all_input_layouts)
def _hash_platform(hash_obj, backend):
_hash_string(hash_obj, backend.platform)
_hash_string(hash_obj, backend.platform_version)
@@ -366,33 +253,5 @@ def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):
_hash_string(hash_obj, flag)
def _hash_int(hash_obj, int_var):
hash_obj.update(int_var.to_bytes(8, byteorder="big"))
def _hash_float(hash_obj, float_var):
hash_obj.update(struct.pack("d", float_var))
def _hash_signed_int(hash_obj, int_var):
hash_obj.update(int_var.to_bytes(8, byteorder="big", signed=True))
def _hash_bool(hash_obj, bool_var):
hash_obj.update(bool_var.to_bytes(1, byteorder="big"))
def _hash_string(hash_obj, str_var):
hash_obj.update(str_var.encode("utf-8").strip())
def _hash_bool_list(hash_obj, bool_list):
for b in bool_list:
_hash_bool(hash_obj, b)
_hash_int(hash_obj, len(bool_list))
def _hash_int_list(hash_obj, int_list):
for i in int_list:
_hash_int(hash_obj, i)
_hash_int(hash_obj, len(int_list))


@@ -167,10 +167,9 @@ def put_executable_and_time(
def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
backend, produce_original_cache_key: bool = True) -> str:
backend) -> str:
return cache_key.get(module, devices, compile_options, backend,
"zstandard" if zstandard is not None else "zlib",
produce_original_cache_key)
"zstandard" if zstandard is not None else "zlib")
def is_initialized() -> bool:


@@ -303,9 +303,7 @@ def compile_or_get_cached(
try:
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
config.use_original_compilation_cache_key_generation.value,
)
computation, devices, compile_options, backend)
except xc._xla.XlaRuntimeError as ex:
logger.error("compile_or_get_cached: unable to generate cache key, "
"skipping the cache: %s", ex)
@@ -321,18 +319,10 @@ def compile_or_get_cached(
assert retrieved_compile_time is not None
logger.debug("Persistent compilation cache hit for '%s'", module_name)
if config.use_original_compilation_cache_key_generation.value:
# TODO(b/293308239) Remove metrics for the original cache after the new
# compilation cache key implementation is fully rolled out.
monitoring.record_event('/jax/compilation_cache/cache_hits_original')
monitoring.record_event_duration_secs(
"/jax/compilation_cache/original_compile_time_saved_sec",
retrieved_compile_time - cache_retrieval_time)
else:
monitoring.record_event('/jax/compilation_cache/cache_hits')
monitoring.record_event_duration_secs(
'/jax/compilation_cache/compile_time_saved_sec',
retrieved_compile_time - cache_retrieval_time)
monitoring.record_event('/jax/compilation_cache/cache_hits')
monitoring.record_event_duration_secs(
'/jax/compilation_cache/compile_time_saved_sec',
retrieved_compile_time - cache_retrieval_time)
monitoring.record_event_duration_secs(
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)


@@ -973,14 +973,6 @@ include_full_tracebacks_in_locations = define_bool_state(
),
)
use_original_compilation_cache_key_generation = define_bool_state(
name='jax_use_original_compilation_cache_key_generation',
default=False,
help="If true, use the original cache-key generation algorithm. This is "
"a transient flag; once the new cache-key generation algorithm is "
"deployed, this flag and the original cache-key generation algorithm "
"will be removed.")
enable_compilation_cache = define_bool_state(
name='jax_enable_compilation_cache',
default=True,


@@ -14,7 +14,6 @@
import hashlib
import os
import random
import sys
import unittest
@@ -38,64 +37,6 @@ config.parse_flags_with_absl()
class CacheKeyTest(jtu.JaxTestCase):
def test_compile_options(self):
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
filled_hash1 = self.get_hashed_value(
cache_key._hash_compile_options, compile_options_filled
)
filled_hash2 = self.get_hashed_value(
cache_key._hash_compile_options, compile_options_filled
)
not_filled_hash3 = self.get_hashed_value(
cache_key._hash_compile_options, compile_options_not_filled
)
self.assertEqual(filled_hash1, filled_hash2)
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_executable_build_options(self):
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
filled_hash1 = self.get_hashed_value(
cache_key._hash_executable_build_options,
compile_options_filled.executable_build_options,
)
filled_hash2 = self.get_hashed_value(
cache_key._hash_executable_build_options,
compile_options_filled.executable_build_options,
)
not_filled_hash3 = self.get_hashed_value(
cache_key._hash_executable_build_options,
compile_options_not_filled.executable_build_options,
)
self.assertEqual(filled_hash1, filled_hash2)
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_debug_options(self):
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
hash1 = self.get_hashed_value(
cache_key._hash_debug_options,
compile_options.executable_build_options.debug_options,
)
hash2 = self.get_hashed_value(
cache_key._hash_debug_options,
compile_options.executable_build_options.debug_options,
)
self.assertEqual(hash1, hash2)
new_debug_options = self.create_new_debug_options(
compile_options.executable_build_options.debug_options
)
hash3 = self.get_hashed_value(
cache_key._hash_debug_options, new_debug_options
)
self.assertNotEqual(hash1, hash3)
def test_serialized_compile_options(self):
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
@@ -156,29 +97,6 @@ class CacheKeyTest(jtu.JaxTestCase):
hash3 = self.get_hashed_value(cache_key._hash_platform, cpu_backend)
self.assertNotEqual(hash1, hash3)
def test_hash_int(self):
hash1 = self.get_hashed_value(cache_key._hash_int, 90)
hash2 = self.get_hashed_value(cache_key._hash_int, 8)
hash3 = self.get_hashed_value(cache_key._hash_int, 8)
self.assertEqual(hash2, hash3)
self.assertNotEqual(hash1, hash2)
def test_hash_signed_int(self):
hash1 = self.get_hashed_value(cache_key._hash_signed_int, 90)
hash2 = self.get_hashed_value(cache_key._hash_signed_int, -90)
hash3 = self.get_hashed_value(cache_key._hash_signed_int, -8)
hash4 = self.get_hashed_value(cache_key._hash_signed_int, -8)
self.assertEqual(hash3, hash4)
self.assertNotEqual(hash1, hash2)
self.assertNotEqual(hash1, hash3)
def test_hash_bool(self):
hash1 = self.get_hashed_value(cache_key._hash_bool, False)
hash2 = self.get_hashed_value(cache_key._hash_bool, True)
hash3 = self.get_hashed_value(cache_key._hash_bool, True)
self.assertEqual(hash2, hash3)
self.assertNotEqual(hash1, hash2)
def test_hash_string(self):
hash1 = self.get_hashed_value(cache_key._hash_string, "foo")
hash2 = self.get_hashed_value(cache_key._hash_string, "bar")
@@ -320,19 +238,6 @@ class CacheKeyTest(jtu.JaxTestCase):
del os.environ["LIBTPU_INIT_ARGS"]
sys.argv = orig_argv
def create_new_debug_options(self, debug_options_obj):
debug_options_obj.xla_cpu_enable_fast_math = False
debug_options_obj.xla_cpu_fast_math_honor_infs = False
debug_options_obj.xla_cpu_fast_math_honor_nans = False
debug_options_obj.xla_cpu_fast_math_honor_division = False
debug_options_obj.xla_cpu_fast_math_honor_functions = False
debug_options_obj.xla_gpu_enable_fast_min_max = False
debug_options_obj.xla_backend_optimization_level = random.randint(0, 10)
debug_options_obj.xla_cpu_enable_xprof_traceme = False
debug_options_obj.xla_llvm_disable_expensive_passes = False
debug_options_obj.xla_test_all_input_layouts = False
return debug_options_obj
def filled_compile_options(self):
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = 1


@@ -302,14 +302,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 1)
# TODO(b/293308239) Remove the parameters after the new compilation cache key
# implementation is enabled.
@parameterized.parameters(True, False)
def test_cache_saving_metric(self, use_original):
def test_cache_saving_metric(self):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
config.use_original_compilation_cache_key_generation(use_original),
):
cc.initialize_cache(tmpdir)
@@ -327,13 +323,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
jit(lambda x: x + 1)(1)
self.assertNotIn(
"/jax/compilation_cache/cache_retrieval_time_sec", durations)
if use_original:
self.assertNotIn(
"/jax/compilation_cache/original_compile_time_saved_sec",
durations)
else:
self.assertNotIn(
"/jax/compilation_cache/compile_time_saved_sec", durations)
self.assertNotIn(
"/jax/compilation_cache/compile_time_saved_sec", durations)
# Mock time to create a long compilation time, metrics incremented with
# a cache hit.
@@ -343,16 +334,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
jit(lambda x: x + 2)(1)
self.assertGreater(
durations["/jax/compilation_cache/cache_retrieval_time_sec"], 0)
if use_original:
self.assertGreater(
durations[
"/jax/compilation_cache/original_compile_time_saved_sec"
], 0)
else:
if xla_bridge.using_pjrt_c_api():
raise SkipTest("PJRT C API not supported yet.")
self.assertGreater(
durations["/jax/compilation_cache/compile_time_saved_sec"], 0)
self.assertGreater(
durations["/jax/compilation_cache/compile_time_saved_sec"], 0)
def test_task_using_cache_metric(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -403,15 +386,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
- previous_counts["/jax/compilation_cache/cache_misses"],
2)
# TODO(b/293308239) Remove the parameters after the new compilation cache key
# implementation is enabled.
@parameterized.parameters(True, False)
def test_cache_hits_metric(self, use_original):
def test_cache_hits_metric(self):
previous_counts = Counter(_counts)
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
config.use_original_compilation_cache_key_generation(use_original),
):
cc.initialize_cache(tmpdir)
@@ -420,16 +399,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
jit(lambda x: x + 1)(1)
jit(lambda x: x + 1)(1)
if use_original:
self.assertEqual(
_counts["/jax/compilation_cache/cache_hits_original"]
- previous_counts["/jax/compilation_cache/cache_hits_original"],
1)
else:
self.assertEqual(
_counts["/jax/compilation_cache/cache_hits"]
- previous_counts["/jax/compilation_cache/cache_hits"],
1)
self.assertEqual(
_counts["/jax/compilation_cache/cache_hits"]
- previous_counts["/jax/compilation_cache/cache_hits"],
1)
@jtu.with_config(