mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add some docstrings for remote DMAs and semaphore barriers.
PiperOrigin-RevId: 627037991
This commit is contained in:
parent
b79f3b77ef
commit
667a0c1fe5
@ -543,6 +543,24 @@ def async_copy(src_ref, dst_ref, sem):
|
||||
|
||||
def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
|
||||
device_id_type: DeviceIdType = DeviceIdType.MESH):
|
||||
"""Creates a description of a remote copy operation.
|
||||
|
||||
Copies data from src_ref on the current device to dst_ref on the device
|
||||
specified by device_id. Both semaphores should be waited on using the
|
||||
descriptor on both source and target devices.
|
||||
|
||||
Note that device_id can also refer to the current device.
|
||||
|
||||
Args:
|
||||
src_ref: The source Reference.
|
||||
dst_ref: The destination Reference.
|
||||
send_sem: The semaphore on the source device.
|
||||
recv_sem: The semaphore on the destination device.
|
||||
device_id: The device id of the destination device.
|
||||
device_id_type: The type of the device id.
|
||||
Returns:
|
||||
An AsyncCopyDescriptor.
|
||||
"""
|
||||
src_ref, src_indexers = _get_ref_and_indexers(src_ref)
|
||||
send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem)
|
||||
dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
|
||||
@ -576,4 +594,24 @@ def _get_barrier_semaphore_abstract_eval():
|
||||
)
|
||||
|
||||
def get_barrier_semaphore():
|
||||
"""Returns a barrier semaphore.
|
||||
|
||||
This function returns a barrier semaphore based on the collective_id of the
|
||||
current pallas kernel.
|
||||
|
||||
It's very important that the semaphore is wait-ed back down to 0, or else the
|
||||
semaphores will become corrupted.
|
||||
|
||||
It's also very important that the collective_id is different for each pallas
|
||||
kernel with communication. E.g. if you have two pallas kernels, one that syncs
|
||||
across the X axis of the device mesh and the second that syncs across the Y
|
||||
axis, they must have different collective_ids.
|
||||
However it is legal for two kernels that perform the same synchronization
|
||||
pattern (e.g. only communicating with neighbours on the same mesh axis)
|
||||
to share a collective_id. However, if in doubt, prefer not sharing
|
||||
collective_ids, as doing so incorrectly can lead to silent data corruption or
|
||||
crashes.
|
||||
Note that re-using the same collective_id doesn't guarantee that the same
|
||||
semaphore is provided by XLA.
|
||||
"""
|
||||
return get_barrier_semaphore_p.bind()
|
||||
|
Loading…
x
Reference in New Issue
Block a user