rocm_jax/tests/compilation_cache_test.py

270 lines
9.4 KiB
Python
Raw Normal View History

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2021-07-20 16:16:21 +00:00
from functools import partial
import math
import os
import tempfile
from unittest import mock
from unittest import SkipTest
import warnings
from absl.testing import absltest
import jax
from jax import config
from jax import jit
from jax import lax
from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import persistent_cache_min_compile_time_secs
from jax._src.config import raise_persistent_cache_errors
from jax._src.lib import xla_client
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
import numpy as np
# Module-level setup executed at import time, before any test in this file
# runs. Presumably syncs absl command-line flags into JAX's config — confirm
# against jax.config.
config.parse_flags_with_absl()
# NOTE(review): FLAGS is not referenced by the tests in this file.
FLAGS = config.FLAGS
# Compile time stored alongside executables written directly via
# cc.put_executable_and_time; test_put_executable checks that the same value
# is read back. Units presumably seconds — confirm against compilation_cache.
FAKE_COMPILE_TIME = 10
@jtu.with_config(
    jax_raise_persistent_cache_errors=True,
    jax_persistent_cache_min_compile_time_secs=0,
)
class CompilationCacheTest(jtu.JaxTestCase):
  """Tests for the persistent compilation cache (`jax._src.compilation_cache`).

  The class-level config makes persistent-cache errors raise, and sets the
  minimum compile time required for caching to 0 seconds. As a result, every
  compiled executable is written to the cache unless a test overrides these
  settings.
  """

  def setUp(self):
    super().setUp()
    supported_platforms = ["tpu", "gpu"]
    # CPU is only included when XLA_FLAGS enables the XLA CPU runtime
    # (presumably required for executable serialization on CPU — confirm).
    if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
      supported_platforms.append("cpu")
    if jtu.device_under_test() not in supported_platforms:
      raise SkipTest(
          "serialize executable only works on " + ",".join(supported_platforms)
      )

    # Reset cache if already initialized by JaxTestCase
    if cc.is_initialized():
      cc.reset_cache()

  def tearDown(self):
    # Do not leak an initialized cache (pointing at a now-deleted temp dir)
    # into later tests.
    if cc.is_initialized():
      cc.reset_cache()
    super().tearDown()

  def test_get_no_executable(self):
    """Looking up a key that was never stored yields (None, None)."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
      devices = np.array([[jax.local_devices()[0]]])
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1
      )
      backend = xla_bridge.get_backend()
      key = cc.get_cache_key(computation, devices, compile_options, backend)
      executable, compile_time = cc.get_executable_and_time(
          key, compile_options, backend
      )
      self.assertIsNone(executable)
      self.assertIsNone(compile_time)

  def test_diff_executables(self):
    """Executables stored under different keys are retrieved separately."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
      computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1
      )
      backend = xla_bridge.get_backend()
      executable1 = backend.compile(computation1, compile_options)
      executable2 = backend.compile(computation2, compile_options)
      cc.put_executable_and_time(
          "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME
      )
      cc.put_executable_and_time(
          "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME
      )
      self.assertNotEqual(
          cc.get_executable_and_time("key1", compile_options, backend)[0],
          cc.get_executable_and_time("key2", compile_options, backend)[0],
      )

  def test_put_executable(self):
    """A stored executable round-trips: same outputs and same compile time."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation = (
          jax.jit(lambda x, y: x + y)
          .lower(np.int32(1), np.int32(1))
          .compiler_ir()
      )
      devices = np.array([[jax.local_devices()[0]]])
      compile_options = xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1
      )
      backend = xla_bridge.get_backend()
      executable = backend.compile(str(computation), compile_options)
      key = cc.get_cache_key(computation, devices, compile_options, backend)
      cc.put_executable_and_time(
          key, "alambda", executable, backend, FAKE_COMPILE_TIME
      )
      executable_retrieved, compile_time_retrieved = cc.get_executable_and_time(
          key, compile_options, backend
      )
      inputs_to_executable = (
          np.array(1, dtype=np.int32),
          np.array(2, dtype=np.int32),
      )
      expected = xla_client.execute_with_python_values(
          executable, inputs_to_executable, backend
      )
      actual = xla_client.execute_with_python_values(
          executable_retrieved, inputs_to_executable, backend
      )
      self.assertEqual(expected, actual)
      self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved)

  def test_pmap(self):
    """Calling a pmapped function with a new input dtype adds a cache entry."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
      # pmap's mapped axis is bounded by the number of *local* devices;
      # jax.device_count() also counts other processes' devices and would
      # produce an invalid input size in multi-process runs.
      num_devices = jax.local_device_count()
      f(np.arange(num_devices, dtype=np.int64))
      self.assertLen(os.listdir(tmpdir), 1)
      f(np.arange(num_devices, dtype=np.float32))
      self.assertLen(os.listdir(tmpdir), 2)
      # TODO: create a test for calling pmap with the same input more than once

  def test_jit(self):
    """Calling a jitted function with a new input type adds a cache entry."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x * x)
      f(1)
      self.assertLen(os.listdir(tmpdir), 1)
      f(1.0)
      self.assertLen(os.listdir(tmpdir), 2)

  @jtu.with_mesh([("x", 2)])
  def test_pjit(self):
    """Calling a pjitted function with a new input dtype adds a cache entry."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)

      @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None)
      def f(x, y):
        return x + y

      shape = (8, 8)
      x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
      f(x, x + 1)
      self.assertLen(os.listdir(tmpdir), 1)
      x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
      f(x, x + 1)
      self.assertLen(os.listdir(tmpdir), 2)

  @jtu.with_mesh([("x", 2)])
  def test_xmap(self):
    """Calling an xmapped function with a new input dtype adds a cache entry."""
    # Skip before touching the cache so no state is created for a skipped test.
    if jax.local_device_count() < 2:
      raise SkipTest("Test requires 2 devices")
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)

      def f(x):
        return x * 2

      x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
      xmap(
          f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
      )(x)
      self.assertLen(os.listdir(tmpdir), 1)

      x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
      xmap(
          f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
      )(x)
      self.assertLen(os.listdir(tmpdir), 2)

  def test_cache_write_warning(self):
    """With errors not raised, a failing cache write becomes one warning."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x * x)
      with raise_persistent_cache_errors(False), mock.patch.object(
          cc._cache.__class__, "put"
      ) as mock_put, warnings.catch_warnings(record=True) as w:
        mock_put.side_effect = RuntimeError("test error")
        # The computation must still succeed despite the cache failure.
        self.assertEqual(f(2), 4)
        self.assertLen(
            w, 1, msg=f"Warnings: {[str(w_.message) for w_ in w]}"
        )
        self.assertIn(
            (
                "Error writing persistent compilation cache entry "
                "for 'jit__lambda_': RuntimeError: test error"
            ),
            str(w[0].message),
        )

  def test_cache_read_warning(self):
    """With errors not raised, a failing cache read becomes one warning."""
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      f = jit(lambda x: x * x)
      with raise_persistent_cache_errors(False), mock.patch.object(
          cc._cache.__class__, "get"
      ) as mock_get, warnings.catch_warnings(record=True) as w:
        mock_get.side_effect = RuntimeError("test error")
        # The computation must still succeed despite the cache failure.
        self.assertEqual(f(2), 4)
        # Report any unexpected extra warnings in the failure message rather
        # than printing them to stdout.
        self.assertLen(
            w, 1, msg=f"Warnings: {[str(w_.message) for w_ in w]}"
        )
        self.assertIn(
            (
                "Error reading persistent compilation cache entry "
                "for 'jit__lambda_': RuntimeError: test error"
            ),
            str(w[0].message),
        )

  def test_min_compile_time(self):
    """Only sufficiently slow compilations are written to the cache.

    jax_persistent_cache_min_compile_time_secs gates cache writes on measured
    compile time (it replaced an earlier instruction-count threshold). Here it
    is set to 2 seconds, and time.monotonic is mocked to control the measured
    duration.
    """
    with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
        2
    ):
      cc.initialize_cache(tmpdir)

      # Mock time to progress in small intervals so compilation time is small.
      with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
        jit(lambda x: x + 1)(1)
      self.assertLen(os.listdir(tmpdir), 0)

      # Mock time to progress in large intervals so compilation time is large.
      with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
        jit(lambda x: x + 2)(1)
      self.assertLen(os.listdir(tmpdir), 1)
if __name__ == "__main__":
  # Run this module's tests through absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)