Include the device_kind in the compilation cache key.

PiperOrigin-RevId: 525726898
This commit is contained in:
Peter Hawkins 2023-04-20 06:16:12 -07:00 committed by jax authors
parent 87c328864b
commit 1d63d9b833
5 changed files with 38 additions and 23 deletions

View File

@ -257,7 +257,7 @@ pytype_strict_library(
":gfile_cache",
":path",
"//jax/_src/lib",
] + py_deps("zstandard"),
] + py_deps("numpy") + py_deps("zstandard"),
)
pytype_strict_library(

View File

@ -21,6 +21,8 @@ import sys
from typing import Any, List, Optional
import zlib
import numpy as np
# If zstandard is installed, we use zstd compression, otherwise we use zlib.
try:
import zstandard
@ -136,7 +138,8 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
)
def get_cache_key(module: ir.Module, compile_options, backend) -> str:
def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
backend) -> str:
"""Creates a hashed string to use as a key to the compilation cache.
get_cache_key takes in the MLIR module and compile_options of a program
@ -148,6 +151,7 @@ def get_cache_key(module: ir.Module, compile_options, backend) -> str:
"""
entries = [
("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
("devices", lambda hash_obj: _hash_devices(hash_obj, devices)),
("compile_options",
lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
("jax_lib version",
@ -191,6 +195,9 @@ def _hash_computation(hash_obj, module):
canonical_ir = _canonicalize_ir(module)
hash_obj.update(canonical_ir)
def _hash_devices(hash_obj, devices: np.ndarray) -> None:
for device in devices.flat:
_hash_string(hash_obj, device.device_kind)
def _hash_compile_options(hash_obj, compile_options_obj):
if xla_extension_version >= 145:

View File

@ -473,8 +473,8 @@ def _dump_ir_to_file(name: str, ir: str):
name.write_text(ir)
def compile_or_get_cached(backend, computation: ir.Module, compile_options,
host_callbacks):
def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
compile_options, host_callbacks):
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
@ -495,7 +495,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
host_callbacks)
cache_key = compilation_cache.get_cache_key(
computation, compile_options, backend)
computation, devices, compile_options, backend)
cached_executable = _cache_read(module_name, cache_key, compile_options,
backend)

View File

@ -1092,7 +1092,8 @@ class UnloadedPmapExecutable:
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
compiled = dispatch.compile_or_get_cached(
pci.backend, xla_computation, compile_options, host_callbacks)
pci.backend, xla_computation, device_assignment, compile_options,
host_callbacks)
return UnloadedPmapExecutable(
compiled=compiled,
@ -2568,7 +2569,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
"in {elapsed_time} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
backend, computation, compile_options, host_callbacks)
backend, computation, dev, compile_options, host_callbacks)
return xla_executable, compile_options

View File

@ -164,37 +164,41 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_same_hash_key(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
self.assertEqual(
cc.get_cache_key(computation, compile_options, backend),
cc.get_cache_key(computation, compile_options, backend),
cc.get_cache_key(computation, devices, compile_options, backend),
cc.get_cache_key(computation, devices, compile_options, backend),
)
def test_different_hash_key(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options_not_filled = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
backend = xla_bridge.get_backend()
self.assertNotEqual(
cc.get_cache_key(computation, compile_options_not_filled, backend),
cc.get_cache_key(computation, compile_options_filled, backend),
cc.get_cache_key(computation, devices, compile_options_not_filled,
backend),
cc.get_cache_key(computation, devices, compile_options_filled, backend),
)
def test_different_computations(self):
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
self.assertNotEqual(
cc.get_cache_key(computation1, compile_options, backend),
cc.get_cache_key(computation2, compile_options, backend),
cc.get_cache_key(computation1, devices, compile_options, backend),
cc.get_cache_key(computation2, devices, compile_options, backend),
)
@unittest.skipIf(jax._src.lib.version < (0, 4, 9),
@ -206,13 +210,14 @@ class CompilationCacheTest(jtu.JaxTestCase):
assert id(f) != id(g)
computation1 = jax.jit(f).lower(1, 1).compiler_ir()
computation2 = jax.jit(g).lower(2, 3).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
with compilation_cache_include_metadata_in_key(include_metadata):
key1 = cc.get_cache_key(computation1, compile_options, backend)
key2 = cc.get_cache_key(computation2, compile_options, backend)
key1 = cc.get_cache_key(computation1, devices, compile_options, backend)
key2 = cc.get_cache_key(computation2, devices, compile_options, backend)
self.assertEqual(include_metadata, key1 != key2)
def test_xla_flags(self):
@ -220,6 +225,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
@ -229,26 +235,26 @@ class CompilationCacheTest(jtu.JaxTestCase):
orig_argv = sys.argv
try:
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
key1 = cc.get_cache_key(computation, compile_options, backend)
key1 = cc.get_cache_key(computation, devices, compile_options, backend)
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
key2 = cc.get_cache_key(computation, compile_options, backend)
key2 = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertNotEqual(key1, key2)
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
key3 = cc.get_cache_key(computation, compile_options, backend)
key3 = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(key1, key3)
# Test flag in _xla_flags_to_exclude_from_cache_key
os.environ["XLA_FLAGS"] = (
"--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8"
)
key4 = cc.get_cache_key(computation, compile_options, backend)
key4 = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(key1, key4)
# Test flags given on command line
del os.environ["XLA_FLAGS"]
sys.argv.append("--xla_gpu_autotune_level=0")
key5 = cc.get_cache_key(computation, compile_options, backend)
key5 = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(key1, key5)
sys.argv.append("--xla_force_host_platform_device_count=8")
self.assertEqual(key1, key5)
@ -264,11 +270,12 @@ class CompilationCacheTest(jtu.JaxTestCase):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
key = cc.get_cache_key(computation, compile_options, backend)
key = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(
cc.get_executable(key, compile_options, backend), None
)
@ -299,13 +306,13 @@ class CompilationCacheTest(jtu.JaxTestCase):
.lower(np.int32(1), np.int32(1))
.compiler_ir()
)
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
executable = backend.compile(str(computation), compile_options)
key = cc.get_cache_key(computation, compile_options, backend)
key = cc.get_cache_key(computation, devices, compile_options, backend)
cc.put_executable(key, "alambda", executable, backend)
deserialized_executable = cc.get_executable(key, compile_options, backend)
inputs_to_executable = (