diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 2224edff4..71542ea5d 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -32,8 +32,6 @@ race:split_keys_entry_added # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx -# https://github.com/python/cpython/issues/130571 -race:_PyObject_GetMethod # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi @@ -63,3 +61,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/128133 # race:bytes_hash + +# https://github.com/python/cpython/issues/130571 +# race:_PyObject_GetMethod diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f30de44d0..f690d5744 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7495,10 +7495,12 @@ class UtilTest(jtu.JaxTestCase): ), ) + @jtu.thread_unsafe_test() def test_op_sharding_cache_on_mesh_pspec_sharding(self): ndim = 2 mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps1 = NamedSharding(mesh, P('x', 'y')) + sharding_impls.named_sharding_to_xla_hlo_sharding.cache_clear() op1 = mps1._to_xla_hlo_sharding(ndim) cache_info1 = sharding_impls.named_sharding_to_xla_hlo_sharding.cache_info()