mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove code for compatibility with jaxlib < 0.5.0.
PiperOrigin-RevId: 724045185
This commit is contained in:
parent
1b7b04f7db
commit
840192d39a
@ -1029,58 +1029,42 @@ def sample_product(*args, **kw):
|
||||
# We use a reader-writer lock to protect test execution. Tests that may run in
|
||||
# parallel acquire a read lock; tests that are not thread-safe acquire a write
|
||||
# lock.
|
||||
if hasattr(util, 'Mutex'):
|
||||
_test_rwlock = util.Mutex()
|
||||
_test_rwlock = util.Mutex()
|
||||
|
||||
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
||||
if getattr(test.__class__, "thread_hostile", False):
|
||||
_test_rwlock.writer_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.writer_unlock()
|
||||
else:
|
||||
_test_rwlock.reader_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.reader_unlock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def thread_unsafe_test():
|
||||
"""Decorator for tests that are not thread-safe.
|
||||
|
||||
Note: this decorator (naturally) only applies to what it wraps, not to, say,
|
||||
code in separate setUp() or tearDown() methods.
|
||||
"""
|
||||
if TEST_NUM_THREADS.value <= 0:
|
||||
yield
|
||||
return
|
||||
|
||||
_test_rwlock.assert_reader_held()
|
||||
_test_rwlock.reader_unlock()
|
||||
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
||||
if getattr(test.__class__, "thread_hostile", False):
|
||||
_test_rwlock.writer_lock()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_test_rwlock.writer_unlock()
|
||||
_test_rwlock.reader_lock()
|
||||
else:
|
||||
# TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum.
|
||||
_test_rwlock = threading.Lock()
|
||||
|
||||
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
||||
_test_rwlock.acquire()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.release()
|
||||
_test_rwlock.writer_unlock()
|
||||
else:
|
||||
_test_rwlock.reader_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.reader_unlock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def thread_unsafe_test():
|
||||
yield # No reader-writer lock, so we get no parallelism.
|
||||
@contextmanager
|
||||
def thread_unsafe_test():
|
||||
"""Decorator for tests that are not thread-safe.
|
||||
|
||||
Note: this decorator (naturally) only applies to what it wraps, not to, say,
|
||||
code in separate setUp() or tearDown() methods.
|
||||
"""
|
||||
if TEST_NUM_THREADS.value <= 0:
|
||||
yield
|
||||
return
|
||||
|
||||
_test_rwlock.assert_reader_held()
|
||||
_test_rwlock.reader_unlock()
|
||||
_test_rwlock.writer_lock()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_test_rwlock.writer_unlock()
|
||||
_test_rwlock.reader_lock()
|
||||
|
||||
|
||||
def thread_unsafe_test_class():
|
||||
|
Loading…
x
Reference in New Issue
Block a user