# rocm_jax/tests/cache_key_test.py
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import re
import sys
import unittest
from typing import cast as type_cast
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import cache_key
from jax._src import compiler
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.mesh import Mesh
from jax._src.partition_spec import PartitionSpec as P
from jax._src.sharding_impls import NamedSharding
from jax._src.custom_partitioning import custom_partitioning
config.parse_flags_with_absl()
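

# The tests below exercise jax._src.cache_key.get(), which derives the
# compilation-cache key from a lowered computation, the target devices, the
# compile options, and the backend. A minimal sketch of the call pattern used
# throughout this file, kept as a comment so it does not run at import time
# (it assumes at least one local device is available):
#
#   computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
#   devices = np.array([[jax.local_devices()[0]]])
#   opts = compiler.get_compile_options(num_replicas=1, num_partitions=1)
#   key = cache_key.get(computation, devices, opts, xla_bridge.get_backend())
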
class CacheKeyTest(jtu.JaxTestCase):
def test_serialized_compile_options(self):
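    """Dump/debug options must not change the compile-options hash.

    Every field toggled below is a dump or logging knob; the hashes taken
    before and after must match, which suggests that
    _hash_serialized_compile_options strips these fields before hashing.
    """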
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
hash1 = self.get_hashed_value(
cache_key._hash_serialized_compile_options, compile_options
)
debug_options = compile_options.executable_build_options.debug_options
debug_options.xla_force_host_platform_device_count = 2
debug_options.xla_dump_to = "foo"
debug_options.xla_dump_hlo_module_re = "bar"
debug_options.xla_dump_hlo_pass_re = "baz"
debug_options.xla_dump_hlo_as_text = True
debug_options.xla_dump_hlo_as_proto = True
debug_options.xla_dump_hlo_as_dot = True
debug_options.xla_dump_hlo_as_url = True
debug_options.xla_dump_hlo_as_html = True
debug_options.xla_dump_fusion_visualization = True
debug_options.xla_dump_hlo_snapshots = True
debug_options.xla_dump_max_hlo_modules = True
debug_options.xla_dump_module_metadata = True
debug_options.xla_dump_compress_protos = True
debug_options.xla_dump_hlo_as_long_text = True
debug_options.xla_dump_disable_metadata = True
debug_options.xla_dump_hlo_pipeline_re = "xyzzy"
debug_options.xla_gpu_experimental_autotune_cache_mode = 2
hash2 = self.get_hashed_value(
cache_key._hash_serialized_compile_options, compile_options
)
self.assertEqual(hash1, hash2)
@jtu.skip_on_devices("cpu")
def test_hash_accelerator_devices(self):
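    """Hashing devices and the accelerator config must be deterministic."""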
devices = np.array([[jax.local_devices()[0]]])
dev_hash1 = self.get_hashed_value(cache_key._hash_devices, devices)
dev_hash2 = self.get_hashed_value(cache_key._hash_devices, devices)
self.assertEqual(dev_hash1, dev_hash2)
acc_hash1 = self.get_hashed_value(
cache_key._hash_accelerator_config, devices, xla_bridge.get_backend())
acc_hash2 = self.get_hashed_value(
cache_key._hash_accelerator_config, devices, xla_bridge.get_backend())
self.assertEqual(acc_hash1, acc_hash2)
def test_hash_platform(self):
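    """The platform hash is deterministic, and an accelerator backend must
    hash differently from the CPU backend."""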
hash1 = self.get_hashed_value(
cache_key._hash_platform, xla_bridge.get_backend()
)
hash2 = self.get_hashed_value(
cache_key._hash_platform, xla_bridge.get_backend()
)
self.assertEqual(hash1, hash2)
if xla_bridge.get_backend().platform != "cpu":
cpu_backend = xla_bridge.get_backend("cpu")
hash3 = self.get_hashed_value(cache_key._hash_platform, cpu_backend)
self.assertNotEqual(hash1, hash3)
def test_hash_string(self):
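    """Equal strings must hash equal; different strings must hash apart."""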
hash1 = self.get_hashed_value(cache_key._hash_string, "foo")
hash2 = self.get_hashed_value(cache_key._hash_string, "bar")
hash3 = self.get_hashed_value(cache_key._hash_string, "bar")
self.assertEqual(hash2, hash3)
self.assertNotEqual(hash1, hash2)
def test_same_key(self):
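    """Identical inputs to cache_key.get must yield identical keys."""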
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
self.assertEqual(
cache_key.get(computation, devices, compile_options, backend),
cache_key.get(computation, devices, compile_options, backend),
)
def test_different_key(self):
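    """Non-default compile options (argument layouts, device assignment, FDO
    profile) must change the cache key."""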
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
backend = xla_bridge.get_backend()
self.assertNotEqual(
cache_key.get(
computation, devices, compile_options_not_filled, backend
),
cache_key.get(computation, devices, compile_options_filled, backend),
)
def test_custom_hook(self):
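    """cache_key.custom_hook's return value is folded into the key, so two
    different hooks must produce two different keys."""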
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
original_custom_hook = cache_key.custom_hook
cache_key.custom_hook = lambda: "hook1"
key1 = cache_key.get(computation, devices, compile_options, backend)
cache_key.custom_hook = lambda: "hook2"
key2 = cache_key.get(computation, devices, compile_options, backend)
cache_key.custom_hook = original_custom_hook
self.assertNotEqual(key1, key2)
def test_different_computations(self):
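    """Different computations (x + y vs. x * y) must get different keys."""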
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
self.assertNotEqual(
cache_key.get(computation1, devices, compile_options, backend),
cache_key.get(computation2, devices, compile_options, backend),
)
def test_custom_partitioning_ptr_removal(self):
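    """custom_partitioning lowers to @CustomSPMDPartitioning custom calls
    whose backend_config carries process-specific callback pointers.
    _remove_callbacks must rewrite those configs to "REMOVED", and hashing
    with IgnoreCallbacks.CUSTOM_PARTITIONING must equal hashing the scrubbed
    module, so that keys are stable across processes.
    """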
def _partition(mesh, arg_shapes, result_shape):
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
return mesh, jax.numpy.add, result_shardings, arg_shardings
def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
return NamedSharding(mesh, arg_shapes[0].sharding.spec)
@custom_partitioning
def _cp_add(x, y):
return jax.numpy.add(x, y)
_cp_add.def_partition(
infer_sharding_from_operands=_infer_sharding_from_operands,
partition=_partition)
devices = np.asarray(jax.devices())
with Mesh(devices, ('x',)) as m:
computation = jax.jit(
_cp_add,
in_shardings=(NamedSharding(m, P('x')),
NamedSharding(m, P('x'))),
out_shardings=NamedSharding(m, P('x'))
).lower(
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
).compiler_ir()
pattern = (
r'stablehlo\.custom_call @CustomSPMDPartitioning\('
r'(.*?)\) \{'
r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
r'\}'
)
with computation.context:
updated_module = cache_key._remove_callbacks(
type_cast(ir.Module, computation.operation.clone()),
ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
)
bcs = [
match[2]
for match in re.findall(pattern, str(updated_module), re.DOTALL)
]
for bc in bcs:
self.assertEqual(bc, "REMOVED")
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
hash_without_callback_ptrs = cache_key.get(
computation,
devices,
compile_options,
backend,
ignore_callbacks=cache_key.IgnoreCallbacks.CUSTOM_PARTITIONING,
)
expected_hash = cache_key.get(
updated_module, devices, compile_options, backend
)
self.assertEqual(expected_hash, hash_without_callback_ptrs)
@jtu.skip_on_devices("cpu")
def test_host_callbacks_ptrs_removed(self):
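    """jax.debug.print lowers to a host-callback custom call; under
    IgnoreCallbacks.ALL its pointer-bearing backend_config must be rewritten
    to "REMOVED"."""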
def _host_callback(x, y):
jax.debug.print("x={x[0]} y={y[0]}", x=x, y=y)
computation = (
jax.jit(_host_callback)
.lower(
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
)
.compiler_ir()
)
pattern = r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
with computation.context:
updated_module = cache_key._remove_callbacks(
type_cast(ir.Module, computation.operation.clone()),
ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
)
bcs = [
match[1]
for match in re.findall(pattern, str(updated_module), re.DOTALL)
]
for bc in bcs:
self.assertEqual(bc, "REMOVED")
def test_different_device_assignment(self):
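    """The device assignment feeds the key on most platforms; the assertions
    below imply it is deliberately ignored on GPU."""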
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options_1 = compiler.get_compile_options(
num_replicas=1, num_partitions=1, device_assignment=np.array([[0]])
)
compile_options_2 = compiler.get_compile_options(
num_replicas=1, num_partitions=1, device_assignment=np.array([[1]])
)
backend = xla_bridge.get_backend()
hash_1 = cache_key.get(computation, devices, compile_options_1, backend)
hash_2 = cache_key.get(computation, devices, compile_options_2, backend)
if backend.platform == "gpu":
self.assertEqual(hash_1, hash_2)
else:
self.assertNotEqual(hash_1, hash_2)
@parameterized.parameters([False, True])
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_identical_computations_different_metadata(self, include_metadata):
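    """f and g lower to the same StableHLO up to source-location metadata, so
    their keys differ exactly when metadata is included in the key."""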
f = lambda x, y: lax.mul(lax.add(x, y), 2)
g = lambda x, y: lax.mul(lax.add(x, y), 2)
assert id(f) != id(g)
computation1 = jax.jit(f).lower(1, 1).compiler_ir()
computation2 = jax.jit(g).lower(2, 3).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
with config.compilation_cache_include_metadata_in_key(include_metadata):
key1 = cache_key.get(computation1, devices, compile_options, backend)
key2 = cache_key.get(computation2, devices, compile_options, backend)
self.assertEqual(include_metadata, key1 != key2)
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_xla_flags(self):
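    """XLA_FLAGS from the environment or argv feed the cache key, except for
    flags in cache_key._xla_flags_to_exclude_from_cache_key."""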
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
orig_xla_flags = os.getenv("XLA_FLAGS")
orig_argv = sys.argv
try:
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
key1 = cache_key.get(computation, devices, compile_options, backend)
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
key2 = cache_key.get(computation, devices, compile_options, backend)
self.assertNotEqual(key1, key2)
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
key3 = cache_key.get(computation, devices, compile_options, backend)
self.assertEqual(key1, key3)
# Test flag in _xla_flags_to_exclude_from_cache_key
os.environ["XLA_FLAGS"] = (
"--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8"
)
key4 = cache_key.get(computation, devices, compile_options, backend)
self.assertEqual(key1, key4)
# Test flags given on command line
del os.environ["XLA_FLAGS"]
sys.argv.append("--xla_gpu_autotune_level=0")
key5 = cache_key.get(computation, devices, compile_options, backend)
self.assertEqual(key1, key5)
sys.argv.append("--xla_force_host_platform_device_count=8")
self.assertEqual(key1, key5)
finally:
if orig_xla_flags is not None:
os.environ["XLA_FLAGS"] = orig_xla_flags
elif os.getenv("XLA_FLAGS") is not None:
del os.environ["XLA_FLAGS"]
sys.argv = orig_argv
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_libtpu_init_args(self):
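    """LIBTPU_INIT_ARGS feeds the cache key, so changing it changes the key."""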
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
orig_libtpu_init_args = os.getenv("LIBTPU_INIT_ARGS")
orig_argv = sys.argv
try:
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_spmd_threshold_for_windowed_einsum_mib=0"
)
key1 = cache_key.get(computation, devices, compile_options, backend)
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_spmd_threshold_for_windowed_einsum_mib=1"
)
key2 = cache_key.get(computation, devices, compile_options, backend)
self.assertNotEqual(key1, key2)
finally:
if orig_libtpu_init_args is not None:
os.environ["LIBTPU_INIT_ARGS"] = orig_libtpu_init_args
elif os.getenv("LIBTPU_INIT_ARGS") is not None:
del os.environ["LIBTPU_INIT_ARGS"]
sys.argv = orig_argv
def filled_compile_options(self):
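    """Builds CompileOptions with non-default layouts, device assignment, and
    FDO profile, as a contrast to the defaults in test_different_key."""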
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = 1
compile_options.num_partitions = 1
shape = xla_client.Shape.array_shape(np.dtype(np.float32), [2])
shape_array = [shape, shape]
compile_options.argument_layouts = shape_array
compile_options.executable_build_options.result_layout = shape
device_assignment = xla_client.DeviceAssignment.create(
np.arange(4).reshape(2, 2)
)
compile_options.device_assignment = device_assignment
compile_options.executable_build_options.device_assignment = (
device_assignment
)
compile_options.executable_build_options.fdo_profile = b"test_profile"
return compile_options
def get_hashed_value(
self, hash_function, hash_function_input1, hash_function_input2=None):
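    """Runs hash_function (which updates the hash object in place) over a
    fresh SHA-256 state and returns the hex digest."""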
hash_obj = hashlib.sha256()
if hash_function_input2 is not None:
hash_function(hash_obj, hash_function_input1, hash_function_input2)
else:
hash_function(hash_obj, hash_function_input1)
return hash_obj.digest().hex()
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())