[Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs

PiperOrigin-RevId: 574210634
This commit is contained in:
Sharad Vikram 2023-10-17 11:19:28 -07:00 committed by jax authors
parent c16b893600
commit 86023f55ea
3 changed files with 103 additions and 40 deletions

View File

@ -26,6 +26,8 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_signal

View File

@ -165,14 +165,60 @@ def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value):
@dataclasses.dataclass
class DMAFuture:
  """Handle for waiting on a DMA that has already been started.

  Instances are still constructed by `dma_start` and `remote_dma_start`.
  """
  # Flattened operands that are passed positionally to `dma_wait_p`.
  flat_args: Any
  # Tree structure produced together with `flat_args` by `tree_flatten`.
  tree: Any
  # Forwarded unchanged to `dma_wait_p`; the local-copy helper passes None.
  device_id_type: DeviceIdType | None

  def wait(self):
    """Binds `dma_wait_p` on the stored operands."""
    dma_wait_p.bind(*self.flat_args, tree=self.tree,
                    device_id_type=self.device_id_type)


@dataclasses.dataclass
class AsyncCopyDescriptor:
  """Describes a DMA so that it can be constructed first and started later.

  Constructing a descriptor binds no primitive. `start` binds `dma_start_p`;
  `wait`, `wait_recv` and `wait_send` bind `dma_wait_p`.

  A descriptor is remote when `src_sem` is set. `src_sem` and `device_id` must
  be given together or both left as None.

  Raises:
    ValueError: if exactly one of `src_sem` and `device_id` is None.
  """
  src_ref: Any
  src_indexer: indexing.NDIndexer
  dst_ref: Any
  dst_indexer: indexing.NDIndexer
  # Semaphore waited on by `wait_recv`.
  dst_sem: int | jax.Array
  # Semaphore waited on by `wait_send`; None for a local copy.
  src_sem: int | jax.Array | None
  # None for a local copy.
  device_id: int | jax.Array | None
  device_id_type: DeviceIdType = DeviceIdType.MESH

  def __post_init__(self):
    # XOR is true when exactly one of the two is missing, which would leave
    # the descriptor neither clearly local nor clearly remote.
    if (self.src_sem is None) ^ (self.device_id is None):
      raise ValueError("Either both or neither `src_sem` and `device_id` "
                       "can be set.")

  @property
  def is_remote(self):
    """True when a source semaphore is set."""
    return self.src_sem is not None

  def start(self):
    """Flattens every operand and binds `dma_start_p`."""
    flat_args, tree = tree_util.tree_flatten((
        self.src_ref,
        self.src_indexer,
        self.dst_ref,
        self.dst_indexer,
        self.dst_sem,
        self.src_sem,
        self.device_id,
    ))
    dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

  def wait(self):
    """Waits on the send side (remote copies only), then on the receive side."""
    if self.is_remote:
      self.wait_send()
    self.wait_recv()

  def wait_recv(self):
    """Binds `dma_wait_p` on `dst_sem` with the destination ref and indexer."""
    wait_args, tree = tree_util.tree_flatten(
        (self.dst_sem, self.dst_ref, self.dst_indexer)
    )
    dma_wait_p.bind(
        *wait_args, tree=tree, device_id_type=self.device_id_type
    )

  def wait_send(self):
    """Binds `dma_wait_p` on `src_sem` with the source ref and indexer.

    Raises:
      ValueError: if the descriptor is local (no `src_sem`).
    """
    if not self.is_remote:
      raise ValueError("Cannot `wait_send` on a local copy.")
    wait_args, tree = tree_util.tree_flatten(
        (self.src_sem, self.src_ref, self.src_indexer)
    )
    dma_wait_p.bind(
        *wait_args, tree=tree, device_id_type=self.device_id_type
    )
# Primitive bound whenever a DMA is started; declared as producing multiple
# results.
dma_start_p = jax_core.Primitive('dma_start')
dma_start_p.multiple_results = True
@ -236,35 +282,33 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
# Register `_dma_start_pp_eqn` as the `pp_eqn_rules` entry for `dma_start_p`.
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
def dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -> DMAFuture:
  """Starts a local DMA and returns a future for waiting on it.

  Indexers are built from the raw indices and each ref's shape, then
  `dma_start_p` is bound with no source semaphore and no device id.

  Returns:
    A `DMAFuture` holding the flattened `(sem, dst_ref, dst_indexer)`.
  """
  make_indexer = indexing.NDIndexer.from_indices_shape
  src_indexer = make_indexer(src_indices, src_ref.shape)
  dst_indexer = make_indexer(dst_indices, dst_ref.shape)

  # The two trailing Nones fill the src_sem and device_id slots.
  start_operands, start_tree = tree_util.tree_flatten(
      (src_ref, src_indexer, dst_ref, dst_indexer, sem, None, None)
  )
  dma_start_p.bind(*start_operands, tree=start_tree, device_id_type=None)

  wait_operands, wait_tree = tree_util.tree_flatten(
      (sem, dst_ref, dst_indexer)
  )
  return DMAFuture(wait_operands, wait_tree, None)
def remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, src_sem,
                     dst_sem, device_id,
                     device_id_type: DeviceIdType) -> tuple[DMAFuture,
                                                            DMAFuture]:
  """Starts a remote DMA and returns futures for both sides.

  Returns:
    A `(send_future, recv_future)` pair. The send future waits on `src_sem`
    with the source ref and indexer; the receive future waits on `dst_sem`
    with the destination ref and indexer.
  """
  make_indexer = indexing.NDIndexer.from_indices_shape
  src_indexer = make_indexer(src_indices, src_ref.shape)
  dst_indexer = make_indexer(dst_indices, dst_ref.shape)

  # Operand order expected by dma_start_p: dst_sem comes before src_sem.
  start_operands, start_tree = tree_util.tree_flatten(
      (src_ref, src_indexer, dst_ref, dst_indexer, dst_sem, src_sem, device_id)
  )
  dma_start_p.bind(
      *start_operands, tree=start_tree, device_id_type=device_id_type
  )

  recv_operands, recv_tree = tree_util.tree_flatten(
      (dst_sem, dst_ref, dst_indexer)
  )
  send_operands, send_tree = tree_util.tree_flatten(
      (src_sem, src_ref, src_indexer)
  )
  send_future = DMAFuture(send_operands, send_tree, device_id_type)
  recv_future = DMAFuture(recv_operands, recv_tree, device_id_type)
  return send_future, recv_future
def _make_copy_descriptor(
    src_ref,
    src_indices,
    dst_ref,
    dst_indices,
    dst_sem,
    src_sem,
    device_id,
    device_id_type,
) -> AsyncCopyDescriptor:
  """Builds an `AsyncCopyDescriptor` from refs and raw indices.

  Each indices argument is turned into an `NDIndexer` using the matching
  ref's shape. No primitive is bound here; the returned descriptor still has
  to be started.
  """
  to_indexer = indexing.NDIndexer.from_indices_shape
  return AsyncCopyDescriptor(
      src_ref=src_ref,
      src_indexer=to_indexer(src_indices, src_ref.shape),
      dst_ref=dst_ref,
      dst_indexer=to_indexer(dst_indices, dst_ref.shape),
      dst_sem=dst_sem,
      src_sem=src_sem,
      device_id=device_id,
      device_id_type=device_id_type,
  )
# Primitive bound by the wait helpers to wait on a DMA's semaphore.
dma_wait_p = jax_core.Primitive('dma_wait')
@ -296,18 +340,33 @@ def _get_ref_and_indexer(ref):
return ref.ref, ref.indexer
return ref, (slice(None),) * len(ref.shape)
def make_async_copy(src_ref, dst_ref, sem):
  """Creates a descriptor for a local DMA from src_ref to dst_ref.

  The copy is only described, not issued: the caller must call `.start()` on
  the returned descriptor to begin the DMA.

  Args:
    src_ref: source ref; split into a ref and indices by
      `_get_ref_and_indexer`.
    dst_ref: destination ref, handled the same way.
    sem: semaphore used as the descriptor's `dst_sem`.

  Returns:
    The descriptor built by `_make_copy_descriptor`, with no source semaphore
    and no device id.
  """
  src_ref, src_indices = _get_ref_and_indexer(src_ref)
  dst_ref, dst_indices = _get_ref_and_indexer(dst_ref)
  return _make_copy_descriptor(src_ref, src_indices, dst_ref, dst_indices, sem,
                               None, None, DeviceIdType.MESH)
def async_copy(src_ref, dst_ref, sem):
  """Issues a DMA copying from src_ref to dst_ref.

  Builds the descriptor with `make_async_copy`, starts it immediately, and
  returns it so the caller can wait on it later.
  """
  copy_descriptor = make_async_copy(src_ref, dst_ref, sem)
  copy_descriptor.start()
  return copy_descriptor
def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                           device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Creates a descriptor for a remote DMA from src_ref to dst_ref.

  The copy is only described, not issued: the caller must call `.start()` on
  the returned descriptor to begin the DMA.

  Args:
    src_ref: source ref; split into a ref and indices by
      `_get_ref_and_indexer`.
    dst_ref: destination ref, handled the same way.
    send_sem: semaphore used as the descriptor's `src_sem`.
    recv_sem: semaphore used as the descriptor's `dst_sem`.
    device_id: forwarded to the descriptor unchanged.
    device_id_type: forwarded to the descriptor unchanged.

  Returns:
    The descriptor built by `_make_copy_descriptor`.
  """
  src_ref, src_indices = _get_ref_and_indexer(src_ref)
  dst_ref, dst_indices = _get_ref_and_indexer(dst_ref)
  # _make_copy_descriptor takes dst_sem before src_sem, so the receive
  # semaphore is passed first.
  return _make_copy_descriptor(
      src_ref, src_indices, dst_ref, dst_indices, recv_sem,
      send_sem, device_id, device_id_type=device_id_type)
def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                      device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Builds a remote copy descriptor, starts it, and returns it.

  The returned descriptor is what the caller later waits on.
  """
  descriptor = make_async_remote_copy(
      src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type
  )
  descriptor.start()
  return descriptor
# Primitive named 'device_id'; presumably backs the exported `device_id`
# helper, whose definition is outside this hunk. TODO confirm.
device_id_p = jax_core.Primitive('device_id')

View File

@ -26,6 +26,8 @@ from jax._src.pallas.mosaic import async_remote_copy
from jax._src.pallas.mosaic import device_id
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic import make_async_copy
from jax._src.pallas.mosaic import make_async_remote_copy
from jax._src.pallas.mosaic import repeat
from jax._src.pallas.mosaic import run_scoped
from jax._src.pallas.mosaic import semaphore_signal