Fix pjit outfeed test to avoid potential deadlocks.

PiperOrigin-RevId: 529076350
This commit is contained in:
Rahul Joshi 2023-05-03 06:50:48 -07:00 committed by jax authors
parent 545c483e50
commit 9d750ae97d

View File

@ -832,18 +832,36 @@ class PJitTest(jtu.BufferDonationTestCase):
execution = threading.Thread(target=_dispatch)
execution.start()
def check_outfeed(d, x):
y, = d.transfer_from_outfeed(
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
self.assertAllClose(x, y, check_dtypes=True)
# Check the expected outfeed for all devices.
def check_outfeed(x_fn):
for didx, d in enumerate(devices):
x = x_fn(didx)
y, = d.transfer_from_outfeed(
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
self.assertAllClose(x, y, check_dtypes=True)
logging.info('Transferring from outfeed for the pjit call')
for didx, d in enumerate(devices):
# Transfer the whole array from all devices for replicated.
check_outfeed(d, x)
# For sharded outfeed, the results are sliced.
check_outfeed(d, x[3 * didx:3 * didx + 3, :])
check_outfeed(d, x[:, 5 * didx:5 * didx + 5])
# Note, when checking results of multiple outfeeds, the loop structure
# should be such that we check a given outfeed for all devices before
# moving on to the next outfeed. If there are any collectives generated
# by pjit, a loop structure like:
# for each device:
# check outfeed#0;
# check outfeed#1;
#
# could cause a deadlock if there is a collective scheduled between the
# 2 outfeeds: device #0, after processing outfeed#0, will execute the
# collective and wait for the other devices to join, but the other devices
# won't execute their collective until their outfeed#0 is executed. This is
# because, on GPU for example, execution of an outfeed on the device is
# blocked until the corresponding `transfer_from_outfeed` is executed on
# the host.
# Transfer the whole array from all devices for replicated.
check_outfeed(lambda didx: x)
# For sharded outfeed, the results are sliced.
check_outfeed(lambda didx: x[3 * didx:3 * didx + 3, :])
check_outfeed(lambda didx: x[:, 5 * didx:5 * didx + 5])
execution.join()