mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix pjit outfeed test avoid potential deadlocks.
PiperOrigin-RevId: 529076350
This commit is contained in:
parent
545c483e50
commit
9d750ae97d
@ -832,18 +832,36 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
execution = threading.Thread(target=_dispatch)
|
||||
execution.start()
|
||||
|
||||
def check_outfeed(d, x):
|
||||
y, = d.transfer_from_outfeed(
|
||||
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
|
||||
self.assertAllClose(x, y, check_dtypes=True)
|
||||
# Check the expected outfeed for all devices.
|
||||
def check_outfeed(x_fn):
|
||||
for didx, d in enumerate(devices):
|
||||
x = x_fn(didx)
|
||||
y, = d.transfer_from_outfeed(
|
||||
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
|
||||
self.assertAllClose(x, y, check_dtypes=True)
|
||||
|
||||
logging.info('Transferring from outfeed for the pjit call')
|
||||
for didx, d in enumerate(devices):
|
||||
# Transfer the whole array from all devices for replicated.
|
||||
check_outfeed(d, x)
|
||||
# For sharded outfeed, the results are sliced.
|
||||
check_outfeed(d, x[3 * didx:3 * didx + 3, :])
|
||||
check_outfeed(d, x[:, 5 * didx:5 * didx + 5])
|
||||
|
||||
# Note, when checking results of multiple outfeeds, the loop structure
|
||||
# should be such that we check a given outfeed for all devices before
|
||||
# moving on to the next outfeed. If there are any collectives generated
|
||||
# by pjit, a loop structutre like:
|
||||
# for each device:
|
||||
# check outfeed#0;
|
||||
# check outfeed#1;
|
||||
#
|
||||
# Could cause a deadlock if there is a collective scheduled between the
|
||||
# 2 outfeeds, as device #0, after processing outfeed#0 will execute the
|
||||
# collective, waiting for other devices to join, but other devices won't
|
||||
# execute their collective until their outfeed#0 is executed. This is
|
||||
# because, for GPU for example, execution of an outfeed on GPU is blocked
|
||||
# till the corresponding `transfer_from_outfeed` is executed on the host.
|
||||
|
||||
# Transfer the whole array from all devices for replicated.
|
||||
check_outfeed(lambda didx: x)
|
||||
# For sharded outfeed, the results are sliced.
|
||||
check_outfeed(lambda didx: x[3 * didx:3 * didx + 3, :])
|
||||
check_outfeed(lambda didx: x[:, 5 * didx:5 * didx + 5])
|
||||
|
||||
execution.join()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user