mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix MSAN errors in cache_key_test
The device_assignment array was never initialized, causing MSAN errors. Replacing it with np.arange fixes the issue. PiperOrigin-RevId: 553469463
This commit is contained in:
parent
e2634a2000
commit
0228bf7d3c
@ -918,18 +918,12 @@ py_test(
|
||||
jax_test(
|
||||
name = "compilation_cache_test",
|
||||
srcs = ["compilation_cache_test.py"],
|
||||
backend_tags = {
|
||||
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
|
||||
},
|
||||
deps = ["//jax:compilation_cache_internal"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "cache_key_test",
|
||||
srcs = ["cache_key_test.py"],
|
||||
backend_tags = {
|
||||
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
|
||||
},
|
||||
deps = ["//jax:cache_key"],
|
||||
)
|
||||
|
||||
|
@ -264,7 +264,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
compile_options.executable_build_options.result_layout = shape
|
||||
|
||||
device_assignment = xla_client.DeviceAssignment.create(
|
||||
np.ndarray(shape=(2, 2))
|
||||
np.arange(4).reshape(2, 2)
|
||||
)
|
||||
compile_options.device_assignment = device_assignment
|
||||
compile_options.executable_build_options.device_assignment = (
|
||||
|
Loading…
x
Reference in New Issue
Block a user