mirror of https://github.com/ROCm/jax.git
synced 2025-04-25 10:06:07 +00:00

fix shard_args logic, closes #1688

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

parent ec89eb9e5f
commit c19e65b7ab
@@ -16,7 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
+from collections import namedtuple, defaultdict
 from contextlib import contextmanager
 import itertools as it
 import operator as op
@@ -49,16 +49,16 @@ _map = safe_map
 def identity(x): return x
 
 def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
-  """Shard an argument data array arg along its leading axis.
+  """Shard each argument data array along its leading axis.
 
   Args:
-    devices: list of Devices of length num_replicas mapping a logical replica
-      index to a physical device.
-    assignments: replica to shard assignment.
-    axis_size: int, size of the axis to be sharded.
     backend: the platform to be used
+    devices: list of Devices mapping replica index to a physical device.
+    assignments: list of integers with the same length as `devices` mapping
+      replica index to an index along the leading axis (i.e. a shard).
+    axis_size: int, size of the leading axis to be sharded.
     args: a sequence of JaxTypes representing arguments to be sharded along
-      their leading axes (or the leading axess of their leaves in the tuple
-      case) and placed on `devices`.
+      their leading axes and placed on `devices`.
 
   Returns:
     A list of device buffers with the same length as `devices` indexed by
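
For a concrete sense of the `assignments` list described in the revised docstring, here is a minimal sketch of a replica-to-shard mapping like the one a helper such as assign_shards_to_replicas (used later in this diff) is assumed to produce. It assumes each shard is placed on a contiguous group of replicas; the name replica_to_shard_assignment is illustrative only, not part of JAX.

# Hedged sketch: replica index -> shard index along the leading axis, assuming
# each shard is replicated on nrep // axis_size consecutive replicas.
def replica_to_shard_assignment(nrep, axis_size):
    group_size, ragged = divmod(nrep, axis_size)
    assert not ragged, "number of replicas must be a multiple of axis_size"
    return [r // group_size for r in range(nrep)]

# Example: 8 replicas sharding a leading axis of size 4.
print(replica_to_shard_assignment(8, 4))  # [0, 0, 1, 1, 2, 2, 3, 3]
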
@@ -72,12 +72,28 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
     # inline handling for ShardedDeviceArray as a special case for performance
     if type(arg) is ShardedDeviceArray:
       if nrep == len(arg.device_buffers):
+        # The argument is already prepared for the right number of replicas, so
+        # we just ensure that buf[r] is on devices[r] for each replica index r
+        # TODO(mattjj): compared to the other case, this logic has less looping
+        # but could incur more device-to-device data movement
         for r, buf in enumerate(arg.device_buffers):
-          buffers[r][a] = (buf if buf.device() == devices[r]
-                           else buf.copy_to_device(devices[r]))
+          buffers[r][a] = buf if buf.device() == devices[r] else buf.copy_to_device(devices[r])
       else:
+        # The argument is prepared for a different number of replicas, so for
+        # each of our replica indices we check if there's already a buffer with
+        # the correct logical assignment on the correct device, and if not just
+        # copy one of them
+        prev_assignments = assign_shards_to_replicas(len(arg.device_buffers), axis_size)
+        candidates = defaultdict(list)
         for r, buf in enumerate(arg.device_buffers):
-          buffers[r][a] = xla.device_put(x[assignments[r]], devices[r], backend=backend)
+          candidates[prev_assignments[r]].append(buf)
+        for r in range(nrep):
+          for buf in candidates[assignments[r]]:
+            if buf.device() == devices[r]:
+              buffers[r][a] = buf
+              break
+          else:
+            buffers[r][a] = buf.copy_to_device(devices[r])
     else:
       bufs = shard_arg_handlers[type(arg)](arg, devices, assignments, backend=backend)
       for r, buf in enumerate(bufs):
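
The reuse-or-copy logic added in the hunk above can be hard to read in diff form, so below is a self-contained model of it. FakeDevice, FakeBuffer, and place_shards are hypothetical stand-ins, not JAX or XLA APIs; they exist only to show how the defaultdict of candidate buffers and the for/else fallback interact.

# Hedged, standalone model of the new reuse-or-copy logic above.
from collections import defaultdict

class FakeDevice:
    def __init__(self, ordinal):
        self.ordinal = ordinal
    def __repr__(self):
        return "dev%d" % self.ordinal

class FakeBuffer:
    def __init__(self, shard, device):
        self.shard, self._device = shard, device
    def device(self):
        return self._device
    def copy_to_device(self, device):
        return FakeBuffer(self.shard, device)

def place_shards(prev_assignments, prev_buffers, assignments, devices):
    # Group the existing buffers by the shard index they hold.
    candidates = defaultdict(list)
    for r, buf in enumerate(prev_buffers):
        candidates[prev_assignments[r]].append(buf)
    out = [None] * len(devices)
    for r in range(len(devices)):
        # Reuse a buffer that already holds the right shard on the right device...
        for buf in candidates[assignments[r]]:
            if buf.device() == devices[r]:
                out[r] = buf
                break
        else:
            # ...otherwise copy one of the candidate buffers over.
            out[r] = buf.copy_to_device(devices[r])
    return out

devices = [FakeDevice(i) for i in range(4)]
prev_assignments = [0, 0, 1, 1]   # old replica -> shard mapping
prev_buffers = [FakeBuffer(s, d) for s, d in zip(prev_assignments, devices)]
assignments = [0, 1, 0, 1]        # new replica -> shard mapping
placed = place_shards(prev_assignments, prev_buffers, assignments, devices)
print([(b.shard, b.device()) for b in placed])
# [(0, dev0), (1, dev1), (0, dev2), (1, dev3)]

Note that, as in the diff, the else branch copies whichever candidate buffer was examined last; this relies on every shard index having at least one candidate buffer.
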
@@ -89,6 +105,7 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
 
   return buffers
 
+
 shard_arg_handlers = {}
 shard_arg_handlers[core.Unit] = \
     lambda x, devices, _, backend=None: [
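
The final hunk shows the shard_arg_handlers registry that the non-ShardedDeviceArray branch dispatches through via shard_arg_handlers[type(arg)](arg, devices, assignments, backend=backend). The sketch below illustrates that registry pattern with a hypothetical array type; MyArray and put_on_device are made-up stand-ins, and only the handler signature is taken from the diff.

# Hedged sketch of the handler-registry pattern: a standalone toy registry,
# not JAX's module-level shard_arg_handlers.
shard_arg_handlers = {}

class MyArray:
    def __init__(self, rows):
        self.rows = rows  # list of per-shard values along the leading axis

def put_on_device(value, device, backend=None):
    # Stand-in for a real device_put; here we just tag the value.
    return ("on", device, value)

def _shard_my_array(arg, devices, assignments, backend=None):
    # One buffer per replica: replica r receives shard assignments[r].
    return [put_on_device(arg.rows[assignments[r]], d, backend=backend)
            for r, d in enumerate(devices)]

shard_arg_handlers[MyArray] = _shard_my_array

arg = MyArray(rows=["shard0", "shard1"])
print(shard_arg_handlers[type(arg)](arg, ["dev0", "dev1", "dev2", "dev3"],
                                    [0, 0, 1, 1], backend=None))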