Add some docstrings for remote DMAs and semaphore barriers.

PiperOrigin-RevId: 627037991
jax authors 2024-04-22 08:00:56 -07:00
parent b79f3b77ef
commit 667a0c1fe5

@@ -543,6 +543,24 @@ def async_copy(src_ref, dst_ref, sem):
def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                           device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Creates a description of a remote copy operation.

  Copies data from src_ref on the current device to dst_ref on the device
  specified by device_id. Both semaphores should be waited on using the
  descriptor on both source and target devices.

  Note that device_id can also refer to the current device.

  Args:
    src_ref: The source Reference.
    dst_ref: The destination Reference.
    send_sem: The semaphore on the source device.
    recv_sem: The semaphore on the destination device.
    device_id: The device id of the destination device.
    device_id_type: The type of the device id.
  Returns:
    An AsyncCopyDescriptor.
  """
  src_ref, src_indexers = _get_ref_and_indexers(src_ref)
  send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem)
  dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
@@ -576,4 +594,24 @@ def _get_barrier_semaphore_abstract_eval():
)
def get_barrier_semaphore():
  """Returns a barrier semaphore.

  This function returns a barrier semaphore based on the collective_id of the
  current Pallas kernel.

  It's very important that the semaphore is waited back down to 0, or else the
  semaphores will become corrupted.

  It's also very important that the collective_id is different for each Pallas
  kernel with communication. E.g. if you have two Pallas kernels, one that
  syncs across the X axis of the device mesh and a second that syncs across
  the Y axis, they must have different collective_ids.

  It is legal, however, for two kernels that perform the same synchronization
  pattern (e.g. only communicating with neighbours on the same mesh axis)
  to share a collective_id. If in doubt, prefer not sharing collective_ids,
  as doing so incorrectly can lead to silent data corruption or crashes.

  Note that re-using the same collective_id doesn't guarantee that the same
  semaphore is provided by XLA.
  """
  return get_barrier_semaphore_p.bind()
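
The semaphore rules in the two docstrings above can be illustrated with a toy, single-process model. This is only a sketch and does not use the Pallas API: ToySemaphore, ToyDevice, toy_barrier, toy_remote_copy and run_ring_shift are hypothetical names invented for illustration, and the "wait" here raises instead of blocking. It assumes a ring of devices in which each device signals both neighbours before copying to its right neighbour.

```python
"""Toy model of the semaphore discipline; NOT the Pallas implementation."""


class ToySemaphore:
  """Hypothetical counting semaphore that raises instead of blocking."""

  def __init__(self):
    self.value = 0

  def signal(self, inc=1):
    self.value += inc

  def wait(self, dec=1):
    # A real wait would block; the toy raises so that mistakes are visible.
    if self.value < dec:
      raise RuntimeError("wait would block: not enough signals")
    self.value -= dec


class ToyDevice:
  """Hypothetical device with a source buffer, a destination buffer and
  the three semaphores the docstrings talk about."""

  def __init__(self, device_id, payload):
    self.device_id = device_id
    self.src = list(payload)
    self.dst = [None] * len(payload)
    self.barrier_sem = ToySemaphore()
    self.send_sem = ToySemaphore()
    self.recv_sem = ToySemaphore()


def toy_barrier(devices):
  """Each device signals both ring neighbours, then waits for 2 signals."""
  n = len(devices)
  for dev in devices:
    for neighbour in ((dev.device_id - 1) % n, (dev.device_id + 1) % n):
      devices[neighbour].barrier_sem.signal()
  for dev in devices:
    # Waiting for exactly as many signals as were received brings the
    # barrier semaphore back down to 0.
    dev.barrier_sem.wait(2)


def toy_remote_copy(src_dev, dst_dev):
  """Models one remote copy: move the data, then wait on both semaphores."""
  dst_dev.dst[:] = src_dev.src   # stands in for the DMA itself
  src_dev.send_sem.signal()      # completion seen by the sender
  dst_dev.recv_sem.signal()      # completion seen by the receiver
  src_dev.send_sem.wait()        # sender waits on send_sem
  dst_dev.recv_sem.wait()        # receiver waits on recv_sem


def run_ring_shift(num_devices=4):
  """Barrier, then every device copies its buffer to its right neighbour."""
  devices = [ToyDevice(i, [i, i]) for i in range(num_devices)]
  toy_barrier(devices)
  for dev in devices:
    toy_remote_copy(dev, devices[(dev.device_id + 1) % num_devices])
  return devices
```

In this model, calling run_ring_shift(4) leaves every semaphore at 0 and each device holding its left neighbour's data. Dropping the wait in toy_barrier would leave the barrier semaphores at 2, which is the toy analogue of the corruption the docstring warns about.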