Fix MSAN errors in cache_key_test

The device_assignment array was never initialized, causing MSAN errors.
Replacing it with np.arange fixes the issue.

PiperOrigin-RevId: 553469463
This commit is contained in:
Adam Paszke 2023-08-03 07:24:14 -07:00 committed by jax authors
parent e2634a2000
commit 0228bf7d3c
2 changed files with 1 additions and 7 deletions

View File

@ -918,18 +918,12 @@ py_test(
jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
backend_tags = {
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
},
deps = ["//jax:compilation_cache_internal"],
)
jax_test(
name = "cache_key_test",
srcs = ["cache_key_test.py"],
backend_tags = {
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
},
deps = ["//jax:cache_key"],
)

View File

@ -264,7 +264,7 @@ class CacheKeyTest(jtu.JaxTestCase):
compile_options.executable_build_options.result_layout = shape
device_assignment = xla_client.DeviceAssignment.create(
np.ndarray(shape=(2, 2))
np.arange(4).reshape(2, 2)
)
compile_options.device_assignment = device_assignment
compile_options.executable_build_options.device_assignment = (