mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Include the device_kind in the compilation cache key.
PiperOrigin-RevId: 525726898
This commit is contained in:
parent
87c328864b
commit
1d63d9b833
@ -257,7 +257,7 @@ pytype_strict_library(
|
||||
":gfile_cache",
|
||||
":path",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("zstandard"),
|
||||
] + py_deps("numpy") + py_deps("zstandard"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
|
@ -21,6 +21,8 @@ import sys
|
||||
from typing import Any, List, Optional
|
||||
import zlib
|
||||
|
||||
import numpy as np
|
||||
|
||||
# If zstandard is installed, we use zstd compression, otherwise we use zlib.
|
||||
try:
|
||||
import zstandard
|
||||
@ -136,7 +138,8 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
|
||||
)
|
||||
|
||||
|
||||
def get_cache_key(module: ir.Module, compile_options, backend) -> str:
|
||||
def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
|
||||
backend) -> str:
|
||||
"""Creates a hashed string to use as a key to the compilation cache.
|
||||
|
||||
get_cache_key takes in the MLIR module and compile_options of a program
|
||||
@ -148,6 +151,7 @@ def get_cache_key(module: ir.Module, compile_options, backend) -> str:
|
||||
"""
|
||||
entries = [
|
||||
("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
|
||||
("devices", lambda hash_obj: _hash_devices(hash_obj, devices)),
|
||||
("compile_options",
|
||||
lambda hash_obj: _hash_compile_options(hash_obj, compile_options)),
|
||||
("jax_lib version",
|
||||
@ -191,6 +195,9 @@ def _hash_computation(hash_obj, module):
|
||||
canonical_ir = _canonicalize_ir(module)
|
||||
hash_obj.update(canonical_ir)
|
||||
|
||||
def _hash_devices(hash_obj, devices: np.ndarray) -> None:
|
||||
for device in devices.flat:
|
||||
_hash_string(hash_obj, device.device_kind)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
if xla_extension_version >= 145:
|
||||
|
@ -473,8 +473,8 @@ def _dump_ir_to_file(name: str, ir: str):
|
||||
name.write_text(ir)
|
||||
|
||||
|
||||
def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
host_callbacks):
|
||||
def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
|
||||
compile_options, host_callbacks):
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
|
||||
@ -495,7 +495,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, compile_options, backend)
|
||||
computation, devices, compile_options, backend)
|
||||
|
||||
cached_executable = _cache_read(module_name, cache_key, compile_options,
|
||||
backend)
|
||||
|
@ -1092,7 +1092,8 @@ class UnloadedPmapExecutable:
|
||||
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
compiled = dispatch.compile_or_get_cached(
|
||||
pci.backend, xla_computation, compile_options, host_callbacks)
|
||||
pci.backend, xla_computation, device_assignment, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
return UnloadedPmapExecutable(
|
||||
compiled=compiled,
|
||||
@ -2568,7 +2569,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = dispatch.compile_or_get_cached(
|
||||
backend, computation, compile_options, host_callbacks)
|
||||
backend, computation, dev, compile_options, host_callbacks)
|
||||
return xla_executable, compile_options
|
||||
|
||||
|
||||
|
@ -164,37 +164,41 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_same_hash_key(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
self.assertEqual(
|
||||
cc.get_cache_key(computation, compile_options, backend),
|
||||
cc.get_cache_key(computation, compile_options, backend),
|
||||
cc.get_cache_key(computation, devices, compile_options, backend),
|
||||
cc.get_cache_key(computation, devices, compile_options, backend),
|
||||
)
|
||||
|
||||
def test_different_hash_key(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options_not_filled = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
backend = xla_bridge.get_backend()
|
||||
self.assertNotEqual(
|
||||
cc.get_cache_key(computation, compile_options_not_filled, backend),
|
||||
cc.get_cache_key(computation, compile_options_filled, backend),
|
||||
cc.get_cache_key(computation, devices, compile_options_not_filled,
|
||||
backend),
|
||||
cc.get_cache_key(computation, devices, compile_options_filled, backend),
|
||||
)
|
||||
|
||||
def test_different_computations(self):
|
||||
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
self.assertNotEqual(
|
||||
cc.get_cache_key(computation1, compile_options, backend),
|
||||
cc.get_cache_key(computation2, compile_options, backend),
|
||||
cc.get_cache_key(computation1, devices, compile_options, backend),
|
||||
cc.get_cache_key(computation2, devices, compile_options, backend),
|
||||
)
|
||||
|
||||
@unittest.skipIf(jax._src.lib.version < (0, 4, 9),
|
||||
@ -206,13 +210,14 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
assert id(f) != id(g)
|
||||
computation1 = jax.jit(f).lower(1, 1).compiler_ir()
|
||||
computation2 = jax.jit(g).lower(2, 3).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
with compilation_cache_include_metadata_in_key(include_metadata):
|
||||
key1 = cc.get_cache_key(computation1, compile_options, backend)
|
||||
key2 = cc.get_cache_key(computation2, compile_options, backend)
|
||||
key1 = cc.get_cache_key(computation1, devices, compile_options, backend)
|
||||
key2 = cc.get_cache_key(computation2, devices, compile_options, backend)
|
||||
self.assertEqual(include_metadata, key1 != key2)
|
||||
|
||||
def test_xla_flags(self):
|
||||
@ -220,6 +225,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
@ -229,26 +235,26 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
orig_argv = sys.argv
|
||||
try:
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
|
||||
key1 = cc.get_cache_key(computation, compile_options, backend)
|
||||
key1 = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
|
||||
key2 = cc.get_cache_key(computation, compile_options, backend)
|
||||
key2 = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertNotEqual(key1, key2)
|
||||
|
||||
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
|
||||
key3 = cc.get_cache_key(computation, compile_options, backend)
|
||||
key3 = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertEqual(key1, key3)
|
||||
|
||||
# Test flag in _xla_flags_to_exclude_from_cache_key
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
"--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8"
|
||||
)
|
||||
key4 = cc.get_cache_key(computation, compile_options, backend)
|
||||
key4 = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertEqual(key1, key4)
|
||||
|
||||
# Test flags given on command line
|
||||
del os.environ["XLA_FLAGS"]
|
||||
sys.argv.append("--xla_gpu_autotune_level=0")
|
||||
key5 = cc.get_cache_key(computation, compile_options, backend)
|
||||
key5 = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertEqual(key1, key5)
|
||||
sys.argv.append("--xla_force_host_platform_device_count=8")
|
||||
self.assertEqual(key1, key5)
|
||||
@ -264,11 +270,12 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
key = cc.get_cache_key(computation, compile_options, backend)
|
||||
key = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertEqual(
|
||||
cc.get_executable(key, compile_options, backend), None
|
||||
)
|
||||
@ -299,13 +306,13 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
.lower(np.int32(1), np.int32(1))
|
||||
.compiler_ir()
|
||||
)
|
||||
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
executable = backend.compile(str(computation), compile_options)
|
||||
key = cc.get_cache_key(computation, compile_options, backend)
|
||||
key = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
cc.put_executable(key, "alambda", executable, backend)
|
||||
deserialized_executable = cc.get_executable(key, compile_options, backend)
|
||||
inputs_to_executable = (
|
||||
|
Loading…
x
Reference in New Issue
Block a user