mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs
PiperOrigin-RevId: 574210634
This commit is contained in:
parent
c16b893600
commit
86023f55ea
@ -26,6 +26,8 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import device_id
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import repeat
|
||||
from jax._src.pallas.mosaic.primitives import run_scoped
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
||||
|
@ -165,14 +165,60 @@ def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value):
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DMAFuture:
  """Handle for a DMA that has already been started.

  Stores the flattened arguments, the pytree structure, and the device id
  type needed to bind `dma_wait_p` later.
  """
  flat_args: Any
  tree: Any
  device_id_type: DeviceIdType | None

  def wait(self):
    """Binds `dma_wait_p` with the arguments stored on this future."""
    dma_wait_p.bind(*self.flat_args, tree=self.tree,
                    device_id_type=self.device_id_type)


@dataclasses.dataclass
class AsyncCopyDescriptor:
  """Describes a DMA copy that can be constructed without being started.

  A descriptor is "remote" when `src_sem` is set; `src_sem` and `device_id`
  must either both be provided or both be `None`.
  """
  src_ref: Any
  src_indexer: indexing.NDIndexer
  dst_ref: Any
  dst_indexer: indexing.NDIndexer
  dst_sem: int | jax.Array
  src_sem: int | jax.Array | None
  device_id: int | jax.Array | None
  device_id_type: DeviceIdType = DeviceIdType.MESH

  def __post_init__(self):
    """Validates that `src_sem` and `device_id` are set together.

    Raises:
      ValueError: if exactly one of `src_sem` / `device_id` is `None`.
    """
    if (self.src_sem is None) ^ (self.device_id is None):
      raise ValueError("Either both or neither `src_sem` and `device_id` "
                       "can be set.")

  @property
  def is_remote(self):
    """Whether this descriptor carries a `src_sem` (i.e. is a remote copy)."""
    return self.src_sem is not None

  def start(self):
    """Flattens the descriptor's operands and binds `dma_start_p`."""
    flat_args, tree = tree_util.tree_flatten((
        self.src_ref,
        self.src_indexer,
        self.dst_ref,
        self.dst_indexer,
        self.dst_sem,
        self.src_sem,
        self.device_id,
    ))
    dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

  def wait(self):
    """Waits on the copy: send side first (remote only), then receive side."""
    if self.is_remote:
      self.wait_send()
    self.wait_recv()

  def wait_recv(self):
    """Binds `dma_wait_p` on `(dst_sem, dst_ref, dst_indexer)`."""
    wait_args, tree = tree_util.tree_flatten(
        (self.dst_sem, self.dst_ref, self.dst_indexer)
    )
    dma_wait_p.bind(
        *wait_args, tree=tree, device_id_type=self.device_id_type
    )

  def wait_send(self):
    """Binds `dma_wait_p` on `(src_sem, src_ref, src_indexer)`.

    Raises:
      ValueError: if the descriptor is not remote (there is no `src_sem`).
    """
    if not self.is_remote:
      raise ValueError("Cannot `wait_send` on a local copy.")
    wait_args, tree = tree_util.tree_flatten(
        (self.src_sem, self.src_ref, self.src_indexer)
    )
    dma_wait_p.bind(
        *wait_args, tree=tree, device_id_type=self.device_id_type
    )
|
||||
|
||||
|
||||
dma_start_p = jax_core.Primitive('dma_start')
|
||||
dma_start_p.multiple_results = True
|
||||
@ -236,35 +282,33 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
|
||||
|
||||
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
|
||||
|
||||
def dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -> DMAFuture:
  """Starts a DMA from `src_ref` to `dst_ref` and returns a future to wait on.

  Indexers are built from the given indices and each ref's shape. The
  returned future holds the flattened `(sem, dst_ref, dst_indexer)` triple.
  """
  build_indexer = indexing.NDIndexer.from_indices_shape
  src_indexer = build_indexer(src_indices, src_ref.shape)
  dst_indexer = build_indexer(dst_indices, dst_ref.shape)

  # The two trailing `None`s fill the `src_sem` and `device_id` slots.
  start_leaves, start_tree = tree_util.tree_flatten(
      (src_ref, src_indexer, dst_ref, dst_indexer, sem, None, None)
  )
  dma_start_p.bind(*start_leaves, tree=start_tree, device_id_type=None)

  wait_leaves, wait_tree = tree_util.tree_flatten((sem, dst_ref, dst_indexer))
  return DMAFuture(wait_leaves, wait_tree, None)
|
||||
|
||||
def remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, src_sem,
                     dst_sem, device_id,
                     device_id_type: DeviceIdType) -> tuple[DMAFuture,
                                                            DMAFuture]:
  """Starts a DMA with a `src_sem` and `device_id`; returns two futures.

  Returns:
    A `(send_future, recv_future)` pair. The send future holds the flattened
    `(src_sem, src_ref, src_indexer)`; the receive future holds the flattened
    `(dst_sem, dst_ref, dst_indexer)`.
  """
  build_indexer = indexing.NDIndexer.from_indices_shape
  src_indexer = build_indexer(src_indices, src_ref.shape)
  dst_indexer = build_indexer(dst_indices, dst_ref.shape)

  start_leaves, start_tree = tree_util.tree_flatten(
      (src_ref, src_indexer, dst_ref, dst_indexer, dst_sem, src_sem, device_id)
  )
  dma_start_p.bind(*start_leaves, tree=start_tree,
                   device_id_type=device_id_type)

  recv_leaves, recv_tree = tree_util.tree_flatten(
      (dst_sem, dst_ref, dst_indexer)
  )
  send_leaves, send_tree = tree_util.tree_flatten(
      (src_sem, src_ref, src_indexer)
  )
  send_future = DMAFuture(send_leaves, send_tree, device_id_type)
  recv_future = DMAFuture(recv_leaves, recv_tree, device_id_type)
  return send_future, recv_future
|
||||
def _make_copy_descriptor(
    src_ref,
    src_indices,
    dst_ref,
    dst_indices,
    dst_sem,
    src_sem,
    device_id,
    device_id_type,
) -> AsyncCopyDescriptor:
  """Builds an `AsyncCopyDescriptor` from refs and raw indices.

  The indices of each ref are converted to an `NDIndexer` using that ref's
  shape; all remaining arguments are forwarded to the descriptor unchanged.
  """
  build_indexer = indexing.NDIndexer.from_indices_shape
  src_indexer = build_indexer(src_indices, src_ref.shape)
  dst_indexer = build_indexer(dst_indices, dst_ref.shape)
  return AsyncCopyDescriptor(
      src_ref=src_ref,
      src_indexer=src_indexer,
      dst_ref=dst_ref,
      dst_indexer=dst_indexer,
      dst_sem=dst_sem,
      src_sem=src_sem,
      device_id=device_id,
      device_id_type=device_id_type,
  )
|
||||
|
||||
|
||||
dma_wait_p = jax_core.Primitive('dma_wait')
|
||||
@ -296,18 +340,33 @@ def _get_ref_and_indexer(ref):
|
||||
return ref.ref, ref.indexer
|
||||
return ref, (slice(None),) * len(ref.shape)
|
||||
|
||||
def make_async_copy(src_ref, dst_ref, sem):
  """Creates a descriptor for a copy from src_ref to dst_ref without starting it.

  Each ref is split into its underlying ref and indices via
  `_get_ref_and_indexer`. `sem` is passed as the descriptor's `dst_sem`;
  `src_sem` and `device_id` are left as `None`, so the descriptor is local.

  Returns:
    The descriptor built by `_make_copy_descriptor`; call `.start()` on it to
    issue the DMA.
  """
  src_ref, src_indices = _get_ref_and_indexer(src_ref)
  dst_ref, dst_indices = _get_ref_and_indexer(dst_ref)
  return _make_copy_descriptor(src_ref, src_indices, dst_ref, dst_indices, sem,
                               None, None, DeviceIdType.MESH)
|
||||
|
||||
def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
|
||||
def async_copy(src_ref, dst_ref, sem):
  """Builds a copy descriptor from src_ref to dst_ref, starts it, and returns it."""
  descriptor = make_async_copy(src_ref, dst_ref, sem)
  descriptor.start()
  return descriptor
|
||||
|
||||
def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                           device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Creates a descriptor for a remote copy without starting it.

  Each ref is split into its underlying ref and indices via
  `_get_ref_and_indexer`. Note the semaphore order: `recv_sem` is passed as
  the descriptor's `dst_sem` and `send_sem` as its `src_sem`.

  Returns:
    The descriptor built by `_make_copy_descriptor`; call `.start()` on it to
    issue the DMA.
  """
  src_ref, src_indices = _get_ref_and_indexer(src_ref)
  dst_ref, dst_indices = _get_ref_and_indexer(dst_ref)
  return _make_copy_descriptor(
      src_ref, src_indices, dst_ref, dst_indices, recv_sem,
      send_sem, device_id, device_id_type=device_id_type)
|
||||
|
||||
def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                      device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Builds a remote copy descriptor, starts it, and returns it."""
  descriptor = make_async_remote_copy(
      src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type
  )
  descriptor.start()
  return descriptor
|
||||
|
||||
device_id_p = jax_core.Primitive('device_id')
|
||||
|
||||
|
@ -26,6 +26,8 @@ from jax._src.pallas.mosaic import async_remote_copy
|
||||
from jax._src.pallas.mosaic import device_id
|
||||
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic import make_async_copy
|
||||
from jax._src.pallas.mosaic import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic import repeat
|
||||
from jax._src.pallas.mosaic import run_scoped
|
||||
from jax._src.pallas.mosaic import semaphore_signal
|
||||
|
Loading…
x
Reference in New Issue
Block a user