From 86023f55eae79e0cb6e808465ec6b658ae7b1dd1 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 17 Oct 2023 11:19:28 -0700 Subject: [PATCH] [Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs PiperOrigin-RevId: 574210634 --- jax/_src/pallas/mosaic/__init__.py | 2 + jax/_src/pallas/mosaic/primitives.py | 139 +++++++++++++++++++-------- jax/experimental/pallas/tpu.py | 2 + 3 files changed, 103 insertions(+), 40 deletions(-) diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index ed1c2afb9..18d724f07 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -26,6 +26,8 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType from jax._src.pallas.mosaic.primitives import async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy from jax._src.pallas.mosaic.primitives import device_id +from jax._src.pallas.mosaic.primitives import make_async_copy +from jax._src.pallas.mosaic.primitives import make_async_remote_copy from jax._src.pallas.mosaic.primitives import repeat from jax._src.pallas.mosaic.primitives import run_scoped from jax._src.pallas.mosaic.primitives import semaphore_signal diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 3175ad5e8..f040374bf 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -165,14 +165,60 @@ def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value): @dataclasses.dataclass -class DMAFuture: - flat_args: Any - tree: Any - device_id_type: DeviceIdType | None +class AsyncCopyDescriptor: + src_ref: Any + src_indexer: indexing.NDIndexer + dst_ref: Any + dst_indexer: indexing.NDIndexer + dst_sem: int | jax.Array + src_sem: int | jax.Array | None + device_id: int | jax.Array | None + device_id_type: DeviceIdType = DeviceIdType.MESH + + def __post_init__(self): + if (self.src_sem is None) ^ (self.device_id is None): + raise ValueError("Either both or neither `src_sem` and `device_id` " + "can be set.") + + @property + def is_remote(self): + return self.src_sem is not None + + def start(self): + flat_args, tree = tree_util.tree_flatten(( + self.src_ref, + self.src_indexer, + self.dst_ref, + self.dst_indexer, + self.dst_sem, + self.src_sem, + self.device_id, + )) + dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) def wait(self): - dma_wait_p.bind(*self.flat_args, tree=self.tree, - device_id_type=self.device_id_type) + if self.is_remote: + self.wait_send() + self.wait_recv() + + def wait_recv(self): + wait_args, tree = tree_util.tree_flatten( + (self.dst_sem, self.dst_ref, self.dst_indexer) + ) + dma_wait_p.bind( + *wait_args, tree=tree, device_id_type=self.device_id_type + ) + + def wait_send(self): + if not self.is_remote: + raise ValueError("Cannot `wait_send` on a local copy.") + wait_args, tree = tree_util.tree_flatten( + (self.src_sem, self.src_ref, self.src_indexer) + ) + dma_wait_p.bind( + *wait_args, tree=tree, device_id_type=self.device_id_type + ) + dma_start_p = jax_core.Primitive('dma_start') dma_start_p.multiple_results = True @@ -236,35 +282,33 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -> DMAFuture: - src_indexer = indexing.NDIndexer.from_indices_shape(src_indices, - src_ref.shape) - dst_indexer = indexing.NDIndexer.from_indices_shape(dst_indices, - dst_ref.shape) - args = (src_ref, src_indexer, dst_ref, dst_indexer, sem, None, None) - flat_args, tree = tree_util.tree_flatten(args) - dma_start_p.bind(*flat_args, tree=tree, device_id_type=None) - wait_args, tree = tree_util.tree_flatten((sem, dst_ref, dst_indexer)) - return DMAFuture(wait_args, tree, None) -def remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, src_sem, - dst_sem, device_id, - device_id_type: DeviceIdType) -> tuple[DMAFuture, - DMAFuture]: - src_indexer = indexing.NDIndexer.from_indices_shape(src_indices, - src_ref.shape) - dst_indexer = indexing.NDIndexer.from_indices_shape(dst_indices, - dst_ref.shape) - args = (src_ref, src_indexer, dst_ref, dst_indexer, dst_sem, src_sem, - device_id) - flat_args, tree = tree_util.tree_flatten(args) - dma_start_p.bind(*flat_args, tree=tree, device_id_type=device_id_type) - recv_wait_args = (dst_sem, dst_ref, dst_indexer) - recv_args, recv_tree = tree_util.tree_flatten(recv_wait_args) - send_wait_args = (src_sem, src_ref, src_indexer) - send_args, send_tree = tree_util.tree_flatten(send_wait_args) - return (DMAFuture(send_args, send_tree, device_id_type), - DMAFuture(recv_args, recv_tree, device_id_type)) +def _make_copy_descriptor( + src_ref, + src_indices, + dst_ref, + dst_indices, + dst_sem, + src_sem, + device_id, + device_id_type, +) -> AsyncCopyDescriptor: + src_indexer = indexing.NDIndexer.from_indices_shape( + src_indices, src_ref.shape + ) + dst_indexer = indexing.NDIndexer.from_indices_shape( + dst_indices, dst_ref.shape + ) + return AsyncCopyDescriptor( + src_ref, + src_indexer, + dst_ref, + dst_indexer, + dst_sem, + src_sem, + device_id, + device_id_type=device_id_type, + ) dma_wait_p = jax_core.Primitive('dma_wait') @@ -296,18 +340,33 @@ def _get_ref_and_indexer(ref): return ref.ref, ref.indexer return ref, (slice(None),) * len(ref.shape) -def async_copy(src_ref, dst_ref, sem): +def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" src_ref, src_indices = _get_ref_and_indexer(src_ref) dst_ref, dst_indices = _get_ref_and_indexer(dst_ref) - return dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) + return _make_copy_descriptor(src_ref, src_indices, dst_ref, dst_indices, sem, + None, None, DeviceIdType.MESH) + +def async_copy(src_ref, dst_ref, sem): + """Issues a DMA copying from src_ref to dst_ref.""" + copy_descriptor = make_async_copy(src_ref, dst_ref, sem) + copy_descriptor.start() + return copy_descriptor + +def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, + device_id_type: DeviceIdType = DeviceIdType.MESH): + src_ref, src_indices = _get_ref_and_indexer(src_ref) + dst_ref, dst_indices = _get_ref_and_indexer(dst_ref) + return _make_copy_descriptor( + src_ref, src_indices, dst_ref, dst_indices, recv_sem, + send_sem, device_id, device_id_type=device_id_type) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: DeviceIdType = DeviceIdType.MESH): - src_ref, src_indices = _get_ref_and_indexer(src_ref) - dst_ref, dst_indices = _get_ref_and_indexer(dst_ref) - return remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, send_sem, - recv_sem, device_id, device_id_type=device_id_type) + copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, + device_id, device_id_type) + copy_descriptor.start() + return copy_descriptor device_id_p = jax_core.Primitive('device_id') diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 54ed013cc..37b3dd670 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -26,6 +26,8 @@ from jax._src.pallas.mosaic import async_remote_copy from jax._src.pallas.mosaic import device_id from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata +from jax._src.pallas.mosaic import make_async_copy +from jax._src.pallas.mosaic import make_async_remote_copy from jax._src.pallas.mosaic import repeat from jax._src.pallas.mosaic import run_scoped from jax._src.pallas.mosaic import semaphore_signal