Added concurrent id_tap tests, disabled for GPU (#3690)

This commit is contained in:
George Necula 2020-07-08 16:08:54 +03:00 committed by GitHub
parent 6b471e2ac6
commit fdd7f0c857
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -336,6 +336,77 @@ where: 3
hcb.barrier_wait() # Wait for receivers to be done
self.assertEqual(3, tap_count)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_concurrent_{concurrent}",
concurrent=concurrent)
for concurrent in [True, False]))
def test_multiple_tap(self, concurrent=False):
"""Call id_tap multiple times, concurrently or in sequence. """
if concurrent and jtu.device_under_test() == "gpu":
# TODO(necula): it seems that on GPU if multiple host threads run
# a jit computation, the mutliple computations are interleaved on the
# GPU. This can result in the outfeed trains being interleaved, which
# will trigger an error. The solution is to fix on GPU the receiving
# logic so that we can outfeed the train as one tuple, and receive it
# one piece as a time. Then the trains should be atomic.
# See also b/160692602.
raise SkipTest("concurrent id_tap not supported on GPU")
received = set()
count = 5
def pause_tap(idx, **kwargs):
received.add(int(idx))
logging.info(f"Starting do_tap {idx}. Sleeping 1sec ...")
time.sleep(0.3)
logging.info(f"Finish do_tap {idx}")
def do_tap(idx):
api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)
if concurrent:
threads = [
threading.Thread(
name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,))
for idx in range(count)
]
[t.start() for t in threads]
[t.join() for t in threads]
else:
for idx in range(count):
do_tap(idx)
hcb.barrier_wait()
self.assertEqual(received, set(range(count)))
# TODO(necula): see comment for test_multiple_tap.
@jtu.skip_on_devices("gpu")
def test_multiple_barriers(self):
"""Call barrier_wait concurrently."""
def pause_tap(*args, **kwargs):
logging.info("pause_tap waiting")
time.sleep(0.3)
logging.info("pause_tap done")
def long_run(x):
return hcb.id_tap(pause_tap, x)
api.jit(long_run)(5.)
def try_barrier(idx):
logging.info(f"Starting test barrier {idx}")
hcb.barrier_wait()
logging.info(f"Finished test barrier {idx}")
threads = [
threading.Thread(
name=f"barrier_{idx}", target=try_barrier, args=(idx,))
for idx in range(3)
]
[t.start() for t in threads]
[t.join() for t in threads]
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
@ -917,32 +988,6 @@ what: x times i
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_multiple_barriers(self):
"""Call barrier_wait concurrently."""
def pause_tap(*args, **kwargs):
logging.info("pause_tap waiting")
time.sleep(2)
logging.info("pause_tap done")
def long_run(x):
return hcb.id_tap(pause_tap, x)
api.jit(long_run)(5.)
def try_barrier(idx):
logging.info(f"Starting test barrier {idx}")
hcb.barrier_wait()
logging.info(f"Finished test barrier {idx}")
threads = [
threading.Thread(
name=f"barrier_{idx}", target=try_barrier, args=(idx,))
for idx in range(3)
]
[t.start() for t in threads]
[t.join() for t in threads]
def test_error_bad_consumer_id(self):
"""Try to use reserved consumer ID 0.