mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning. Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
368 lines
16 KiB
Python
368 lines
16 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import hashlib
|
|
import os
|
|
import random
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from unittest import mock, SkipTest
|
|
import warnings
|
|
|
|
from absl.testing import absltest
|
|
from jax.experimental import PartitionSpec as P
|
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
from jax.experimental.maps import xmap
|
|
from jax.experimental.pjit import pjit
|
|
import jax
|
|
from jax import jit, lax, pmap
|
|
from jax._src.util import prod
|
|
import jax._src.test_util as jtu
|
|
from jax._src.lib import xla_bridge
|
|
from jax._src.lib import xla_client
|
|
import numpy as np
|
|
|
|
from jax.config import config
|
|
from jax._src.config import raise_persistent_cache_errors
|
|
|
|
# Parse absl command-line flags into JAX's config at import time.
config.parse_flags_with_absl()
# Module-level alias for the parsed flag values held by `config`.
FLAGS = config.FLAGS
|
|
@jtu.with_config(jax_raise_persistent_cache_errors=True)
class CompilationCacheTest(jtu.JaxTestCase):
  """Tests for the persistent compilation cache module `cc`.

  The class-level config turns on `jax_raise_persistent_cache_errors`, so cache
  read/write failures raise. The warning tests switch it off locally with
  `raise_persistent_cache_errors(False)` and check that a warning is issued
  instead.
  """

  def setUp(self):
    super().setUp()
    # Executable serialization is tested on TPU always, and on GPU only when
    # the XLA runtime executable flag is present in XLA_FLAGS.
    supported_platforms = ["tpu"]
    if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
      supported_platforms.append("gpu")
    if jtu.device_under_test() not in supported_platforms:
      raise SkipTest("serialize executable only works on " +
                     ",".join(supported_platforms))

  def tearDown(self):
    super().tearDown()
    # Reset the module-level cache so the next test starts without one.
    cc._cache = None

  def test_compile_options(self):
    compile_options_not_filled = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    compile_options_filled = self.filled_compile_options()
    filled_hash1 = self.get_hashed_value(cc._hash_compile_options, compile_options_filled)
    filled_hash2 = self.get_hashed_value(cc._hash_compile_options, compile_options_filled)
    not_filled_hash3 = self.get_hashed_value(cc._hash_compile_options, compile_options_not_filled)
    self.assertEqual(filled_hash1, filled_hash2)
    self.assertNotEqual(filled_hash1, not_filled_hash3)

  def test_executable_build_options(self):
    compile_options_not_filled = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    compile_options_filled = self.filled_compile_options()
    filled_hash1 = self.get_hashed_value(cc._hash_executable_build_options,
                                         compile_options_filled.executable_build_options)
    filled_hash2 = self.get_hashed_value(cc._hash_executable_build_options,
                                         compile_options_filled.executable_build_options)
    not_filled_hash3 = self.get_hashed_value(cc._hash_executable_build_options,
                                             compile_options_not_filled.executable_build_options)
    self.assertEqual(filled_hash1, filled_hash2)
    self.assertNotEqual(filled_hash1, not_filled_hash3)

  def test_debug_options(self):
    compile_options = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    hash1 = self.get_hashed_value(cc._hash_debug_options,
                                  compile_options.executable_build_options.debug_options)
    hash2 = self.get_hashed_value(cc._hash_debug_options,
                                  compile_options.executable_build_options.debug_options)
    self.assertEqual(hash1, hash2)
    new_debug_options = self.create_new_debug_options(compile_options.executable_build_options.debug_options)
    hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options)
    self.assertNotEqual(hash1, hash3)

  def test_hash_platform(self):
    hash1 = self.get_hashed_value(cc._hash_platform, xla_bridge.get_backend())
    hash2 = self.get_hashed_value(cc._hash_platform, xla_bridge.get_backend())
    self.assertEqual(hash1, hash2)
    if xla_bridge.get_backend().platform != "cpu":
      cpu_backend = xla_bridge.get_backend("cpu")
      hash3 = self.get_hashed_value(cc._hash_platform, cpu_backend)
      self.assertNotEqual(hash1, hash3)

  def test_hash_int(self):
    hash1 = self.get_hashed_value(cc._hash_int, 90)
    hash2 = self.get_hashed_value(cc._hash_int, 8)
    hash3 = self.get_hashed_value(cc._hash_int, 8)
    self.assertEqual(hash2, hash3)
    self.assertNotEqual(hash1, hash2)

  def test_hash_bool(self):
    hash1 = self.get_hashed_value(cc._hash_bool, False)
    hash2 = self.get_hashed_value(cc._hash_bool, True)
    hash3 = self.get_hashed_value(cc._hash_bool, True)
    self.assertEqual(hash2, hash3)
    self.assertNotEqual(hash1, hash2)

  def test_hash_string(self):
    hash1 = self.get_hashed_value(cc._hash_string, "foo")
    hash2 = self.get_hashed_value(cc._hash_string, "bar")
    hash3 = self.get_hashed_value(cc._hash_string, "bar")
    self.assertEqual(hash2, hash3)
    self.assertNotEqual(hash1, hash2)

  def test_same_hash_key(self):
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    compile_options = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    backend = xla_bridge.get_backend()
    self.assertEqual(cc.get_cache_key(computation, compile_options, backend),
                     cc.get_cache_key(computation, compile_options, backend))

  def test_different_hash_key(self):
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    compile_options_not_filled = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    compile_options_filled = self.filled_compile_options()
    backend = xla_bridge.get_backend()
    self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled, backend),
                        cc.get_cache_key(computation, compile_options_filled, backend))

  def test_different_computations(self):
    computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
    computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
    compile_options = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    backend = xla_bridge.get_backend()
    self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
                        cc.get_cache_key(computation2, compile_options, backend))

  def test_xla_flags(self):
    """XLA flags from XLA_FLAGS and sys.argv feed the cache key, minus exclusions."""
    if jtu.is_device_tpu_v4():
      raise unittest.SkipTest("TODO(b/240151176)")

    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    compile_options = xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    backend = xla_bridge.get_backend()

    orig_xla_flags = os.getenv("XLA_FLAGS")
    # Take a copy: sys.argv is appended to in place below. Keeping a reference
    # to the same list would make the restore in `finally` a no-op, and the
    # appended flags would leak into later tests.
    orig_argv = list(sys.argv)
    try:
      os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
      key1 = cc.get_cache_key(computation, compile_options, backend)
      os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
      key2 = cc.get_cache_key(computation, compile_options, backend)
      self.assertNotEqual(key1, key2)

      os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
      key3 = cc.get_cache_key(computation, compile_options, backend)
      self.assertEqual(key1, key3)

      # Test flag in _xla_flags_to_exclude_from_cache_key
      os.environ["XLA_FLAGS"] = (
          "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8")
      key4 = cc.get_cache_key(computation, compile_options, backend)
      self.assertEqual(key1, key4)

      # Test flags given on command line
      del os.environ["XLA_FLAGS"]
      sys.argv.append("--xla_gpu_autotune_level=0")
      key5 = cc.get_cache_key(computation, compile_options, backend)
      self.assertEqual(key1, key5)
      sys.argv.append("--xla_force_host_platform_device_count=8")
      # Recompute the key after adding the excluded flag. Re-asserting on
      # key5 would never observe the new argv entry.
      key6 = cc.get_cache_key(computation, compile_options, backend)
      self.assertEqual(key1, key6)

    finally:
      if orig_xla_flags is not None:
        os.environ["XLA_FLAGS"] = orig_xla_flags
      elif os.getenv("XLA_FLAGS") is not None:
        del os.environ["XLA_FLAGS"]
      sys.argv = orig_argv

  def test_get_no_executable(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1)
      backend = xla_bridge.get_backend()
      self.assertEqual(cc.get_executable(computation, compile_options, backend), None)

  def test_diff_executables(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
      computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1)
      backend = xla_bridge.get_backend()
      executable1 = backend.compile(computation1, compile_options)
      executable2 = backend.compile(computation2, compile_options)
      cc.put_executable("computation1", computation1, compile_options,
                        executable1, backend)
      cc.put_executable("computation2", computation2, compile_options,
                        executable2, backend)
      self.assertNotEqual(cc.get_executable(computation1, compile_options, backend),
                          cc.get_executable(computation2, compile_options, backend))

  def test_put_executable(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation = jax.xla_computation(lambda x, y: x + y)(np.int32(1),
                                                            np.int32(1))
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1)
      backend = xla_bridge.get_backend()
      executable = backend.compile(computation, compile_options)
      cc.put_executable("alambda", computation, compile_options, executable,
                        backend)
      deserialized_executable = cc.get_executable(computation, compile_options, backend)
      inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
      expected = xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
      actual = xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
      self.assertEqual(expected, actual)

  def test_pmap(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
      x = np.arange(jax.device_count(), dtype=np.int64)
      f(x)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 1)
      x = np.arange(jax.device_count(), dtype=np.float32)
      f(x)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 2)
      #TODO: create a test for calling pmap with the same input more than once

  def test_jit(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x*x)
      f(1)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 1)
      f(1.0)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 2)

  @jtu.with_mesh([('x', 2)])
  def test_pjit(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      @partial(pjit,
               in_axis_resources=(P('x'), P('x')),
               out_axis_resources=None)
      def f(x, y):
        return x + y

      shape = (8, 8)
      x = np.arange(prod(shape), dtype=np.int64).reshape(shape)
      f(x, x + 1)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 1)
      x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
      f(x, x + 1)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 2)

  @jtu.with_mesh([('x', 2)])
  def test_xmap(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      def f(x):
        return x * 2
      devices = np.array(jax.local_devices()[:2])
      if devices.size < 2:
        raise SkipTest("Test requires 2 devices")
      x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 1)
      x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
      files_in_directory = len(os.listdir(tmpdir))
      self.assertEqual(files_in_directory, 2)

  def test_cache_write_warning(self):
    """A failing cache `put` warns instead of raising when errors are off."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x*x)

      with raise_persistent_cache_errors(False), \
           mock.patch.object(cc._cache.__class__, 'put') as mock_put, \
           warnings.catch_warnings(record=True) as w:
        # Record the warning even if one was already issued from the same
        # location earlier in this process. Without this, the per-location
        # "once" registry can suppress it.
        warnings.simplefilter("always")
        mock_put.side_effect = RuntimeError("test error")
        self.assertEqual(f(2), 4)
        self.assertLen(w, 1)
        self.assertIn(
            "Error writing persistent compilation cache entry "
            "for 'jit__lambda_': RuntimeError: test error",
            str(w[0].message))

  def test_cache_read_warning(self):
    """A failing cache `get` warns instead of raising when errors are off."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x*x)

      with raise_persistent_cache_errors(False), \
           mock.patch.object(cc._cache.__class__, 'get') as mock_get, \
           warnings.catch_warnings(record=True) as w:
        # See test_cache_write_warning: make sure a repeat warning is recorded.
        warnings.simplefilter("always")
        mock_get.side_effect = RuntimeError("test error")
        self.assertEqual(f(2), 4)
        self.assertLen(w, 1)
        self.assertIn(
            "Error reading persistent compilation cache entry "
            "for 'jit__lambda_': RuntimeError: test error",
            str(w[0].message))

  def create_new_debug_options(self, debug_options_obj):
    """Overwrite several debug-option fields in place and return the same object.

    Every boolean field set here becomes False. `xla_backend_optimization_level`
    gets a random int in [0, 10].
    """
    debug_options_obj.xla_cpu_enable_fast_math = False
    debug_options_obj.xla_cpu_fast_math_honor_infs = False
    debug_options_obj.xla_cpu_fast_math_honor_nans = False
    debug_options_obj.xla_cpu_fast_math_honor_division = False
    debug_options_obj.xla_cpu_fast_math_honor_functions = False
    debug_options_obj.xla_gpu_enable_fast_min_max = False
    # NOTE(review): the random level may equal the original value. The hash
    # change in test_debug_options then relies on the other fields differing
    # from their prior values — confirm.
    debug_options_obj.xla_backend_optimization_level = random.randint(0, 10)
    debug_options_obj.xla_cpu_enable_xprof_traceme = False
    debug_options_obj.xla_llvm_disable_expensive_passes = False
    debug_options_obj.xla_test_all_input_layouts = False
    return debug_options_obj

  def filled_compile_options(self):
    """Build CompileOptions with replicas, layouts and a 2x2 device assignment set."""
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = 1
    compile_options.num_partitions = 1
    shape = xla_client.Shape.array_shape(np.dtype(np.float32), [2])
    shape_array = [shape, shape]
    compile_options.argument_layouts = shape_array
    compile_options.executable_build_options.result_layout = shape

    # Use concrete integer device ids. `np.ndarray(shape=(2, 2))` allocates
    # uninitialized float memory, which made the assignment's contents
    # arbitrary from run to run.
    device_assignment = xla_client.DeviceAssignment.create(
        np.arange(4).reshape(2, 2))
    compile_options.device_assignment = device_assignment
    compile_options.executable_build_options.device_assignment = device_assignment
    return compile_options

  def get_hashed_value(self, hash_function, hash_function_input):
    """Feed `hash_function_input` into a fresh SHA-256 via `hash_function`.

    Returns:
      The resulting digest as a hex string.
    """
    hash_obj = hashlib.sha256()
    hash_function(hash_obj, hash_function_input)
    return hash_obj.digest().hex()
|
|
if __name__ == "__main__":
  # When executed as a script, hand discovery to JAX's own test loader.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)