mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added concurrent id_tap tests, disabled for GPU (#3690)
This commit is contained in:
parent
6b471e2ac6
commit
fdd7f0c857
@ -336,6 +336,77 @@ where: 3
|
||||
hcb.barrier_wait() # Wait for receivers to be done
|
||||
self.assertEqual(3, tap_count)
|
||||
|
||||
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_concurrent_{concurrent}",
            concurrent=concurrent)
        for concurrent in [True, False]))
def test_multiple_tap(self, concurrent=False):
  """Call id_tap multiple times, concurrently or in sequence.

  Every call taps its own index. After `hcb.barrier_wait()` returns, the
  test checks that the set of indices seen by the tap function is exactly
  `range(count)`.

  Args:
    concurrent: if True, each jitted computation is launched from its own
      host thread; otherwise the computations are run one after another on
      the calling thread.

  Raises:
    SkipTest: for the concurrent variant when the device under test is a GPU.
  """
  if concurrent and jtu.device_under_test() == "gpu":
    # TODO(necula): it seems that on GPU if multiple host threads run
    # a jit computation, the multiple computations are interleaved on the
    # GPU. This can result in the outfeed trains being interleaved, which
    # will trigger an error. The solution is to fix on GPU the receiving
    # logic so that we can outfeed the train as one tuple, and receive it
    # one piece at a time. Then the trains should be atomic.
    # See also b/160692602.
    raise SkipTest("concurrent id_tap not supported on GPU")

  received = set()
  count = 5
  # Single source of truth for the pause, so the log message below cannot
  # drift away from the actual sleep duration.
  pause_sec = 0.3

  def pause_tap(idx, **kwargs):
    # Record the index first, then stall, so that taps overlap in time when
    # they are issued concurrently.
    received.add(int(idx))
    logging.info(f"Starting do_tap {idx}. Sleeping {pause_sec}sec ...")
    time.sleep(pause_sec)
    logging.info(f"Finish do_tap {idx}")

  def do_tap(idx):
    api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)

  if concurrent:
    threads = [
        threading.Thread(
            name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,))
        for idx in range(count)
    ]
    # Start every thread before joining any of them, so the computations
    # really are in flight at the same time.
    for t in threads:
      t.start()
    for t in threads:
      t.join()
  else:
    for idx in range(count):
      do_tap(idx)

  hcb.barrier_wait()
  self.assertEqual(received, set(range(count)))
|
||||
|
||||
# TODO(necula): see comment for test_multiple_tap.
@jtu.skip_on_devices("gpu")
def test_multiple_barriers(self):
  """Call barrier_wait concurrently.

  Runs one jitted computation whose tap function sleeps, then calls
  `hcb.barrier_wait()` from three host threads at once. The test fails if
  any of those calls raises.
  """

  def pause_tap(*args, **kwargs):
    logging.info("pause_tap waiting")
    time.sleep(0.3)
    logging.info("pause_tap done")

  def long_run(x):
    return hcb.id_tap(pause_tap, x)

  api.jit(long_run)(5.)

  # An exception raised inside a worker thread does not propagate to the
  # thread that joins it; collect failures here so the test can report them.
  failures = []

  def try_barrier(idx):
    logging.info(f"Starting test barrier {idx}")
    try:
      hcb.barrier_wait()
    except Exception as e:
      failures.append((idx, e))
      raise  # Keep the traceback that the threading machinery reports.
    logging.info(f"Finished test barrier {idx}")

  threads = [
      threading.Thread(
          name=f"barrier_{idx}", target=try_barrier, args=(idx,))
      for idx in range(3)
  ]
  # Start every thread before joining any, so the barrier calls overlap.
  for t in threads:
    t.start()
  for t in threads:
    t.join()

  self.assertEqual([], failures)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
dict(
|
||||
@ -917,32 +988,6 @@ what: x times i
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
|
||||
def test_multiple_barriers(self):
  """Exercise barrier_wait from several host threads at the same time.

  A jitted computation with a sleeping tap function is dispatched first;
  three threads then each call `hcb.barrier_wait()` and are joined.
  """

  def pause_tap(*args, **kwargs):
    # Stall inside the tap function for two seconds.
    logging.info("pause_tap waiting")
    time.sleep(2)
    logging.info("pause_tap done")

  def long_run(x):
    return hcb.id_tap(pause_tap, x)

  api.jit(long_run)(5.)

  def try_barrier(idx):
    logging.info(f"Starting test barrier {idx}")
    hcb.barrier_wait()
    logging.info(f"Finished test barrier {idx}")

  waiters = []
  for idx in range(3):
    waiters.append(
        threading.Thread(
            name=f"barrier_{idx}", target=try_barrier, args=(idx,)))

  # Launch all waiters before joining any of them.
  for waiter in waiters:
    waiter.start()
  for waiter in waiters:
    waiter.join()
|
||||
|
||||
def test_error_bad_consumer_id(self):
|
||||
"""Try to use reserved consumer ID 0.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user