1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Remove code for compatibility with jaxlib < 0.5.0.

PiperOrigin-RevId: 724045185
This commit is contained in:
Peter Hawkins 2025-02-06 13:08:04 -08:00 committed by jax authors
parent 1b7b04f7db
commit 840192d39a

@ -1029,58 +1029,42 @@ def sample_product(*args, **kw):
# We use a reader-writer lock to protect test execution. Tests that may run in
# parallel acquire a read lock; tests that are not thread-safe acquire a write
# lock.
if hasattr(util, 'Mutex'):
_test_rwlock = util.Mutex()
_test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
if getattr(test.__class__, "thread_hostile", False):
_test_rwlock.writer_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager
def thread_unsafe_test():
"""Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0:
yield
return
_test_rwlock.assert_reader_held()
_test_rwlock.reader_unlock()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
if getattr(test.__class__, "thread_hostile", False):
_test_rwlock.writer_lock()
try:
yield
finally:
_test_rwlock.writer_unlock()
_test_rwlock.reader_lock()
else:
# TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum.
_test_rwlock = threading.Lock()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
_test_rwlock.acquire()
try:
test(result) # type: ignore
finally:
_test_rwlock.release()
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager
def thread_unsafe_test():
yield # No reader-writer lock, so we get no parallelism.
@contextmanager
def thread_unsafe_test():
"""Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0:
yield
return
_test_rwlock.assert_reader_held()
_test_rwlock.reader_unlock()
_test_rwlock.writer_lock()
try:
yield
finally:
_test_rwlock.writer_unlock()
_test_rwlock.reader_lock()
def thread_unsafe_test_class():