Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
Move implementation of compilation cache out of jax/experimental and into jax/_src.

Use a Protocol instead of an abstract base class for the CacheInterface, since it allows us to use one fewer file. No functional change intended.

PiperOrigin-RevId: 524855263
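For orientation, a Protocol-style interface could look roughly like the sketch below. This is an illustration of the pattern the message describes, not the verbatim contents of the new jax/_src/compilation_cache_interface.py; the get/put method set is inferred from how GFileCache is used elsewhere in this diff.

# Hypothetical sketch of a Protocol-based cache interface. Unlike an
# abc.ABC base class, a typing.Protocol is matched structurally, so an
# implementation such as GFileCache does not have to inherit from it.
from typing import Optional, Protocol


class CacheInterfaceSketch(Protocol):

  def get(self, key: str) -> Optional[bytes]:
    """Returns the cached value for `key`, or None if absent."""
    ...

  def put(self, key: str, value: bytes) -> None:
    """Stores `value` under `key`."""
    ...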
This commit is contained in:
parent febd339742
commit 017548c40b

jax/BUILD (32 changes)
@@ -146,10 +146,7 @@ py_library_providing_imports_info(
         "experimental/shard_map.py",
         # until checkify is moved out of experimental
         "experimental/checkify.py",
         # to avoid circular dependencies
-        "experimental/compilation_cache/compilation_cache.py",
-        "experimental/compilation_cache/gfile_cache.py",
-        "experimental/compilation_cache/cache_interface.py",
     ],
     lib_rule = pytype_library,
     pytype_srcs = glob(
@@ -165,6 +162,7 @@ py_library_providing_imports_info(
         ":api_util",
         ":basearray",
         ":cloud_tpu_init",
+        ":compilation_cache_internal",
         ":config",
         ":core",
         ":custom_api_util",
@@ -246,6 +244,24 @@ pytype_strict_library(
     srcs = ["_src/cloud_tpu_init.py"],
 )
 
+pytype_strict_library(
+    name = "compilation_cache_internal",
+    srcs = ["_src/compilation_cache.py"],
+    visibility = [":internal"] + jax_visibility("compilation_cache"),
+    deps = [
+        ":compilation_cache_interface",
+        ":gfile_cache",
+        ":path",
+        "//jax/_src/lib",
+    ],
+)
+
+pytype_strict_library(
+    name = "compilation_cache_interface",
+    srcs = ["_src/compilation_cache_interface.py"],
+    deps = [":path"],
+)
+
 pytype_strict_library(
     name = "config",
     srcs = ["_src/config.py"],
@@ -300,6 +316,15 @@ pytype_strict_library(
     ] + py_deps("numpy"),
 )
 
+pytype_strict_library(
+    name = "gfile_cache",
+    srcs = ["_src/gfile_cache.py"],
+    deps = [
+        ":compilation_cache_interface",
+        ":path",
+    ],
+)
+
 pytype_library(
     name = "iree",
     srcs = ["_src/iree.py"],
@@ -664,7 +689,6 @@ pytype_library(
     name = "compilation_cache",
     srcs = [
         "experimental/compilation_cache/compilation_cache.py",
-        "experimental/compilation_cache/gfile_cache.py",
     ],
     visibility = ["//visibility:public"],
     deps = [":jax"],
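The new :compilation_cache_internal, :compilation_cache_interface, and :gfile_cache targets above give the core jax library a direct dependency edge on the cache, which is what permits the plain jax._src import used in the dispatch changes later in this diff, e.g.:

# With the new BUILD edge in place, internal modules can import the cache
# directly instead of reaching into jax.experimental (avoiding the import
# cycle noted in the dispatch hunks below).
from jax._src import compilation_cache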
jax/_src/compilation_cache.py (new file, 285 lines; the hunk below is entirely added lines)
@@ -0,0 +1,285 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
import re
import sys
from typing import List, Optional

from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_extension_version

logger = logging.getLogger(__name__)

_cache: Optional[CacheInterface] = None


def initialize_cache(path):
  """Creates a global cache object. Should only be called once per process.

  Will throw an assertion error if called a second time with a different path.

  Args:
    path: path for the cache directory.
  """
  global _cache
  if _cache is not None and _cache._path == pathlib.Path(path):
    logger.warning("Cache already previously initialized at %s", _cache._path)
    return

  assert _cache is None, f"The cache path has already been initialized to {_cache._path}"
  _cache = GFileCache(path)
  logger.warning("Initialized persistent compilation cache at %s", path)


def get_executable(xla_computation, compile_options,
                   backend) -> Optional[xla_client.LoadedExecutable]:
  """Returns the cached executable if present, or None otherwise."""
  assert _cache is not None, "initialize_cache must be called before you can call get_executable()"
  cache_key = get_cache_key(xla_computation, compile_options, backend)
  xla_executable_serialized = _cache.get(cache_key)
  if not xla_executable_serialized:
    return None
  xla_executable_deserialized = backend.deserialize_executable(
      xla_executable_serialized,
      compile_options)
  return xla_executable_deserialized


def put_executable(module_name, xla_computation, compile_options,
                   executable: xla_client.LoadedExecutable, backend):
  """Adds 'executable' to the cache, possibly evicting older entries."""
  assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
  cache_key = get_cache_key(xla_computation, compile_options, backend)
  logger.info('Writing %s to persistent compilation cache with key %s.',
              module_name, cache_key)
  serialized_executable = backend.serialize_executable(executable)
  _cache.put(cache_key, serialized_executable)


def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
  if logger.isEnabledFor(logging.DEBUG):
    # Log the hash of just this entry
    fresh_hash_obj = hashlib.sha256()
    hashfn(fresh_hash_obj)
    logger.debug("get_cache_key hash of serialized %s: %s", last_serialized,
                 fresh_hash_obj.digest().hex())
    # Log the cumulative hash
    logger.debug("get_cache_key hash after serializing %s: %s",
                 last_serialized, hash_obj.digest().hex())


def get_cache_key(xla_computation, compile_options, backend) -> str:
  """Creates a hashed string to use as a key to the compilation cache.

  get_cache_key takes in the xla_computation and compile_options of a program
  and hashes all the components into a unique byte string. This byte string is
  returned as a regular string that is 64 characters long (the hex digest of a
  SHA-256 hash).

  Typical return value example:
   '14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
  """
  entries = [
      ("computation",
       lambda hash_obj: _hash_computation(hash_obj, xla_computation)),
      ("compile_options",
       lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
      ("jax_lib version",
       lambda hash_obj: hash_obj.update(bytes(jaxlib_version_str.encode('utf-8')))),
      ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)),
      ("XLA flags", _hash_xla_flags),
  ]

  hash_obj = hashlib.sha256()
  for name, hashfn in entries:
    hashfn(hash_obj)
    _log_cache_key_hash(hash_obj, name, hashfn)
  return hash_obj.digest().hex()


def _hash_computation(hash_obj, xla_computation):
  # The HLO op_name metadata sometimes includes Python function pointers,
  # which cause spurious cache misses. Scrub anything that looks like a
  # function pointer. Example op_name metadata:
  #  op_name="jit(s)/custom_jvp_call_jaxpr
  #   [ jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f3fa30f0940>\n
  #   num_consts=0 ]"
  # TODO(skye): in theory this could cause us to scrub meaningful binary proto
  # data. Do something more robust.
  if isinstance(xla_computation, bytes):
    serialized_hlo = xla_computation  # MLIR module bytecode
  elif isinstance(xla_computation, str):
    serialized_hlo = xla_computation.encode()  # MLIR module text
  else:
    raise TypeError(f"Unknown computation type {type(xla_computation)}")
  scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
  hash_obj.update(scrubbed_hlo)


def _hash_compile_options(hash_obj, compile_options_obj):
  if xla_extension_version >= 145:
    expected_num_compile_options = 12
  else:
    expected_num_compile_options = 11
  # Ignore private and built-in methods. These can unexpectedly change and lead
  # to false positives, e.g. when different Python versions include different
  # built-ins.
  num_actual_options = len(
      [x for x in dir(compile_options_obj) if not x.startswith("_")])
  assert num_actual_options == expected_num_compile_options, (
      "Unexpected number of CompileOption fields: "
      f"{num_actual_options}. This likely means that an extra "
      "field was added, and this function needs to be updated."
  )

  if compile_options_obj.argument_layouts is not None:
    # NOTE: map() is lazy in Python 3, so as written this builds an iterator
    # that is never consumed; the argument layouts are not actually hashed.
    map(lambda shape: hash_obj.update(shape.to_serialized_proto()),
        compile_options_obj.argument_layouts)
  _hash_int(hash_obj, compile_options_obj.parameter_is_tupled_arguments)
  _hash_executable_build_options(hash_obj, compile_options_obj.executable_build_options)
  _hash_bool(hash_obj, compile_options_obj.tuple_arguments)
  _hash_int(hash_obj, compile_options_obj.num_replicas)
  _hash_int(hash_obj, compile_options_obj.num_partitions)
  _hash_int(hash_obj, compile_options_obj.profile_version)
  if compile_options_obj.device_assignment is not None:
    hash_obj.update(compile_options_obj.device_assignment.serialize())
  _hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
  if xla_extension_version >= 145:
    _hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
    for kv in compile_options_obj.env_option_overrides:
      _hash_string(hash_obj, kv[0])
      if isinstance(kv[1], str):
        _hash_string(hash_obj, kv[1])
      elif isinstance(kv[1], bool):
        _hash_bool(hash_obj, kv[1])
      else:
        raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))


def _hash_executable_build_options(hash_obj, executable_obj):
  expected_options = 10
  # Ignore private and built-in methods. These can unexpectedly change and lead
  # to false positives, e.g. when different Python versions include different
  # built-ins.
  actual_options = len(
      [x for x in dir(executable_obj) if not x.startswith('_')])
  assert actual_options == expected_options, (
      "Unexpected number of executable_build_options fields: "
      f"{actual_options}, expected: {expected_options}. This likely means "
      "that an extra field was added, and this function needs to be updated.")
  if executable_obj.result_layout is not None:
    hash_obj.update(executable_obj.result_layout.to_serialized_proto())
  _hash_int(hash_obj, executable_obj.num_replicas)
  _hash_int(hash_obj, executable_obj.num_partitions)
  _hash_debug_options(hash_obj, executable_obj.debug_options)
  if executable_obj.device_assignment is not None:
    hash_obj.update(executable_obj.device_assignment.serialize())
  _hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
  _hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning)
  if executable_obj.use_auto_spmd_partitioning:
    if executable_obj.auto_spmd_partitioning_mesh_shape is not None:
      _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_shape)
    if executable_obj.auto_spmd_partitioning_mesh_ids is not None:
      _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_ids)
  _hash_bool_list(hash_obj,
                  executable_obj.allow_spmd_sharding_propagation_to_output)


def _hash_debug_options(hash_obj, debug_obj):
  _hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math)
  _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_infs)
  _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_nans)
  _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_division)
  _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_functions)
  _hash_bool(hash_obj, debug_obj.xla_gpu_enable_fast_min_max)
  _hash_int(hash_obj, debug_obj.xla_backend_optimization_level)
  _hash_bool(hash_obj, debug_obj.xla_cpu_enable_xprof_traceme)
  _hash_bool(hash_obj, debug_obj.xla_llvm_disable_expensive_passes)
  _hash_bool(hash_obj, debug_obj.xla_test_all_input_layouts)


def _hash_platform(hash_obj, backend):
  _hash_string(hash_obj, backend.platform)
  _hash_string(hash_obj, backend.platform_version)
  _hash_string(hash_obj, backend.runtime_type)


_xla_flags_to_exclude_from_cache_key = [
    "--xla_dump_compress_protos",
    "--xla_dump_module_metadata",
    "--xla_dump_max_hlo_modules",
    "--xla_dump_include_timestamp",
    "--xla_dump_hlo_pass_re",
    "--xla_dump_hlo_module_re",
    "--xla_dump_hlo_snapshots",
    "--xla_dump_fusion_visualization",
    "--xla_dump_hlo_as_url",
    "--xla_dump_hlo_as_proto",
    "--xla_dump_hlo_as_text",
    "--xla_dump_to",
    "--xla_force_host_platform_device_count",
    "--xla_dump_disable_metadata",
    "--xla_dump_hlo_pipeline_re",
    "--xla_tpu_sdc_checker_streamz_metric",
    "--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
]

extra_flag_prefixes_to_include_in_cache_key: List[str] = []


def _hash_xla_flags(hash_obj):
  xla_flags = []

  xla_flags_env_var = os.getenv("XLA_FLAGS")
  if xla_flags_env_var:
    xla_flags.extend(xla_flags_env_var.split())

  for arg in sys.argv:
    if arg.startswith("--xla") or any(
        arg.startswith(p) for p in extra_flag_prefixes_to_include_in_cache_key):
      xla_flags.append(arg)

  # N.B. all XLA flags that take an argument must use '=' and not a space
  # (e.g. --xla_force_host_platform_device_count=8) (I think).
  for flag in xla_flags:
    if flag.split('=')[0] in _xla_flags_to_exclude_from_cache_key:
      logger.debug("Not including XLA flag in cache key: %s", flag)
      continue
    logger.debug("Including XLA flag in cache key: %s", flag)
    _hash_string(hash_obj, flag)


def _hash_int(hash_obj, int_var):
  hash_obj.update(int_var.to_bytes(8, byteorder='big'))


def _hash_bool(hash_obj, bool_var):
  hash_obj.update(bool_var.to_bytes(1, byteorder='big'))


def _hash_string(hash_obj, str_var):
  hash_obj.update(str_var.encode('utf-8').strip())


def _hash_bool_list(hash_obj, bool_list):
  for b in bool_list:
    _hash_bool(hash_obj, b)
  _hash_int(hash_obj, len(bool_list))


def _hash_int_list(hash_obj, int_list):
  for i in int_list:
    _hash_int(hash_obj, i)
  _hash_int(hash_obj, len(int_list))


def is_initialized():
  return _cache is not None


def reset_cache():
  global _cache
  assert is_initialized()
  logger.info("Resetting cache at %s.", _cache._path)
  _cache = None
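Taken together, the new module has a small lifecycle: initialize the process-global cache once, then route reads and writes through it, keyed by get_cache_key. A minimal usage sketch of the functions defined above, where the cache directory is an arbitrary example path (end users would normally go through the jax.experimental re-export kept later in this diff):

# Minimal usage sketch; /tmp/jax_cache is an arbitrary example path,
# not a path mandated by jax.
from jax._src import compilation_cache

if not compilation_cache.is_initialized():
  compilation_cache.initialize_cache("/tmp/jax_cache")
# From here on, dispatch consults the cache with keys produced by
# get_cache_key(computation, compile_options, backend).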
@@ -14,7 +14,12 @@
 
 from abc import ABC, abstractmethod
 
+from jax._src import path as pathlib
+
 
 class CacheInterface(ABC):
+  _path: pathlib.Path
+
   @abstractmethod
   def get(self, key: str):
     pass
@@ -31,6 +31,7 @@ import warnings
 
 import numpy as np
 
+from jax._src import compilation_cache
 from jax._src import core
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -466,9 +467,6 @@ def _dump_ir_to_file(name: str, ir: str):
 
 def compile_or_get_cached(backend, computation: ir.Module, compile_options,
                           host_callbacks):
-  # Avoid import cycle between jax and jax.experimental
-  from jax.experimental.compilation_cache import compilation_cache as cc
-
   sym_name = computation.operation.attributes['sym_name']
   module_name = ir.StringAttr(sym_name).value
 
@@ -490,7 +488,8 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
   # (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
   if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
     supported_platforms.append("cpu")
-  if cc.is_initialized() and backend.platform in supported_platforms:
+  if (compilation_cache.is_initialized() and
+      backend.platform in supported_platforms):
     cached_executable = _cache_read(serialized_computation, module_name,
                                     compile_options, backend)
     if cached_executable is not None:
@@ -513,11 +512,9 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
                 compile_options: CompileOptions,
                 backend: Backend) -> Optional[xc.LoadedExecutable]:
   """Looks up `computation` in the persistent compilation cache."""
-  # Avoid import cycle between jax and jax.experimental
-  from jax.experimental.compilation_cache import compilation_cache as cc
-
   try:
-    return cc.get_executable(computation, compile_options, backend)
+    return compilation_cache.get_executable(computation, compile_options,
+                                            backend)
   except Exception as ex:
     if config.jax_raise_persistent_cache_errors:
       raise
@@ -533,9 +530,6 @@ def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
                  backend: Backend, compiled: xc.LoadedExecutable,
                  host_callbacks: List[Any]):
   """Writes `serialized_computation` to the persistent compilation cache."""
-  # Avoid import cycle between jax and jax.experimental
-  from jax.experimental.compilation_cache import compilation_cache as cc
-
   if host_callbacks:
     logger.info(
         "Not writing persistent cache entry for '%s' because it uses host "
@@ -557,8 +551,8 @@ def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
               compile_time_secs)
 
   try:
-    cc.put_executable(module_name, serialized_computation, compile_options,
-                      compiled, backend)
+    compilation_cache.put_executable(module_name, serialized_computation,
+                                     compile_options, compiled, backend)
   except Exception as ex:
     if config.jax_raise_persistent_cache_errors:
       raise
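The hunks above rewire dispatch from the lazily imported cc alias to the new top-level jax._src import. The caching flow itself is a plain read-through/write-through pattern; below is a self-contained paraphrase, where the callables are simplified stand-ins for backend compilation, _cache_read, and _cache_write, not the real jax internals:

# Paraphrase of the flow in compile_or_get_cached above, with the jax
# internals abstracted into injected callables.
from typing import Any, Callable, Optional

def compile_or_get_cached_sketch(
    compile_fn: Callable[[bytes], Any],
    cache_get: Callable[[str], Optional[Any]],
    cache_put: Callable[[str, Any], None],
    key: str,
    serialized: bytes,
) -> Any:
  cached = cache_get(key)     # read-through: consult the cache first
  if cached is not None:
    return cached             # hit: skip XLA compilation entirely
  executable = compile_fn(serialized)
  cache_put(key, executable)  # write-through: persist for later processes
  return executable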
@@ -14,8 +14,8 @@
 
 import os
 
-from jax.experimental.compilation_cache.cache_interface import CacheInterface
 from jax._src import path as pathlib
+from jax._src.compilation_cache_interface import CacheInterface
 
 
 class GFileCache(CacheInterface):
@@ -12,273 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from jax._src.compilation_cache import (
+  is_initialized as is_initialized,
+  initialize_cache as initialize_cache,
+  reset_cache as reset_cache,
+)

[The remaining ~265 removed lines are omitted here: they are the old
implementation body of this experimental module (imports, initialize_cache,
get_executable, put_executable, get_cache_key, and the _hash_* helpers),
essentially identical to the new jax/_src/compilation_cache.py shown in
full above.]
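The redundant-looking `is_initialized as is_initialized` spelling above is the conventional way to mark a name as an explicit re-export: under strict type checking (e.g. pytype, or mypy with --no-implicit-reexport), a bare `from m import x` is treated as private to the importing module, while `from m import x as x` declares x part of the shim's public API. A minimal illustration:

# In a shim module, the `as` form re-exports the name explicitly:
from os.path import join as join  # `shim.join` stays public under strict checking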
@@ -861,7 +861,7 @@ py_test(
     srcs = ["gfile_cache_test.py"],
     deps = [
         "//jax",
-        "//jax:compilation_cache",
+        "//jax:gfile_cache",
         "//jax:test_util",
     ],
 )
@@ -873,8 +873,7 @@ jax_test(
         "tpu": ["nomsan"],  # TODO(b/213388298): this test fails msan.
     },
     deps = [
-        "//jax:compilation_cache",
-        "//jax:experimental",
+        "//jax:compilation_cache_internal",
     ],
 )
@@ -24,20 +24,21 @@ from unittest import mock, SkipTest
 import warnings
 
 from absl.testing import absltest
-from jax.sharding import PartitionSpec as P
-from jax.experimental.compilation_cache import compilation_cache as cc
-from jax.experimental.maps import xmap
-from jax.experimental.pjit import pjit
+
 import jax
 from jax import jit, lax, pmap
-import jax._src.test_util as jtu
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as P
+from jax._src import compilation_cache as cc
+from jax._src import test_util as jtu
+from jax._src import xla_bridge
+from jax._src.config import (persistent_cache_min_compile_time_secs,
+                             raise_persistent_cache_errors)
+from jax._src.lib import xla_client
 
 import numpy as np
 
 from jax.config import config
-from jax._src.config import (persistent_cache_min_compile_time_secs,
-                             raise_persistent_cache_errors)
 
 config.parse_flags_with_absl()
 FLAGS = config.FLAGS
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from absl.testing import absltest
-from jax.experimental.compilation_cache.gfile_cache import GFileCache
-import jax._src.test_util as jtu
 import tempfile
 import threading
 
+from absl.testing import absltest
+
+from jax._src.gfile_cache import GFileCache
+import jax._src.test_util as jtu
+
 
 class FileSystemCacheTest(jtu.JaxTestCase):
 
   def test_get_nonexistent_key(self):
@@ -57,29 +60,31 @@ class FileSystemCacheTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
       cache.get("")
 
   def test_threads(self):
     file_contents1 = "1" * (65536 + 1)
    file_contents2 = "2" * (65536 + 1)
 
     def call_multiple_puts_and_gets(cache):
-      for i in range(50):
-        cache.put("foo", file_contents1.encode('utf-8').strip())
-        cache.put("foo", file_contents2.encode('utf-8').strip())
+      for _ in range(50):
+        cache.put("foo", file_contents1.encode("utf-8").strip())
+        cache.put("foo", file_contents2.encode("utf-8").strip())
         cache.get("foo")
-      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())
+      self.assertEqual(
+          cache.get("foo"), file_contents2.encode("utf-8").strip()
+      )
 
     with tempfile.TemporaryDirectory() as tmpdir:
       cache = GFileCache(tmpdir)
       threads = []
-      for i in range(50):
+      for _ in range(50):
         t = threading.Thread(target=call_multiple_puts_and_gets(cache))
         t.start()
         threads.append(t)
       for t in threads:
         t.join()
 
-      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())
+      self.assertEqual(cache.get("foo"), file_contents2.encode("utf-8").strip())
 
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
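One pre-existing quirk is carried through unchanged in test_threads above: threading.Thread(target=call_multiple_puts_and_gets(cache)) calls the function immediately on the main thread and passes its return value (None) as the thread target, so the spawned threads themselves do nothing. A sketch of the presumably intended call, using the names from the test above:

# Presumably intended: pass the callable and its argument separately so the
# work actually runs on the new thread.
t = threading.Thread(target=call_multiple_puts_and_gets, args=(cache,))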