fix shard_args logic, closes #1688

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
Matthew Johnson 2019-11-14 16:15:50 -08:00
parent ec89eb9e5f
commit c19e65b7ab

@@ -16,7 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
+from collections import namedtuple, defaultdict
 from contextlib import contextmanager
 import itertools as it
 import operator as op
@@ -49,16 +49,16 @@ _map = safe_map
 def identity(x): return x
 
 def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
-  """Shard an argument data array arg along its leading axis.
+  """Shard each argument data array along its leading axis.
 
   Args:
-    devices: list of Devices of length num_replicas mapping a logical replica
-      index to a physical device.
-    assignments: replica to shard assignment.
-    axis_size: int, size of the axis to be sharded.
+    backend: the platform to be used
+    devices: list of Devices mapping replica index to a physical device.
+    assignments: list of integers with the same length as `devices` mapping
+      replica index to an index along the leading axis (i.e. a shard).
+    axis_size: int, size of the leading axis to be sharded.
     args: a sequence of JaxTypes representing arguments to be sharded along
-      their leading axes (or the leading axess of their leaves in the tuple
-      case) and placed on `devices`.
+      their leading axes and placed on `devices`.
 
   Returns:
     A list of device buffers with the same length as `devices` indexed by
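To make the `assignments` mapping documented above concrete: with 4 replicas and a leading axis of size 2, each shard ends up replicated on two devices. A minimal sketch of such a mapping; the `make_assignments` helper below is a hypothetical stand-in for illustration only, not this module's actual `assign_shards_to_replicas`:

# Hypothetical stand-in for assign_shards_to_replicas, for illustration:
# each shard index is assigned to a contiguous block of replicas.
def make_assignments(num_replicas, axis_size):
  assert num_replicas % axis_size == 0
  replicas_per_shard = num_replicas // axis_size
  return [r // replicas_per_shard for r in range(num_replicas)]

print(make_assignments(4, 2))  # [0, 0, 1, 1]: replicas 0-1 hold shard 0, replicas 2-3 hold shard 1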
@@ -72,12 +72,28 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
     # inline handling for ShardedDeviceArray as a special case for performance
     if type(arg) is ShardedDeviceArray:
       if nrep == len(arg.device_buffers):
+        # The argument is already prepared for the right number of replicas, so
+        # we just ensure that buf[r] is on devices[r] for each replica index r
+        # TODO(mattjj): compared to the other case, this logic has less looping
+        # but could incur more device-to-device data movement
         for r, buf in enumerate(arg.device_buffers):
-          buffers[r][a] = (buf if buf.device() == devices[r]
-                           else buf.copy_to_device(devices[r]))
+          buffers[r][a] = buf if buf.device() == devices[r] else buf.copy_to_device(devices[r])
       else:
+        # The argument is prepared for a different number of replicas, so for
+        # each of our replica indices we check if there's already a buffer with
+        # the correct logical assignment on the correct device, and if not just
+        # copy one of them
+        prev_assignments = assign_shards_to_replicas(len(arg.device_buffers), axis_size)
+        candidates = defaultdict(list)
         for r, buf in enumerate(arg.device_buffers):
-          buffers[r][a] = xla.device_put(x[assignments[r]], devices[r], backend=backend)
+          candidates[prev_assignments[r]].append(buf)
+        for r in range(nrep):
+          for buf in candidates[assignments[r]]:
+            if buf.device() == devices[r]:
+              buffers[r][a] = buf
+              break
+          else:
+            buffers[r][a] = buf.copy_to_device(devices[r])
     else:
       bufs = shard_arg_handlers[type(arg)](arg, devices, assignments, backend=backend)
       for r, buf in enumerate(bufs):
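The deleted line in the hunk above is the bug this commit's "closes #1688" refers to: it indexed `x[assignments[r]]`, but no name `x` is in scope at that point (the value being sharded is `arg`), so reaching that branch raised a NameError. The replacement groups the existing buffers by logical shard and reuses one already on the target device when possible. A standalone sketch of that reuse strategy, assuming buffer objects with `device()` and `copy_to_device()` methods and at least one candidate buffer per shard:

from collections import defaultdict

def reuse_or_copy(old_buffers, prev_assignments, new_assignments, devices):
  # Group the existing buffers by the logical shard each one holds.
  candidates = defaultdict(list)
  for r, buf in enumerate(old_buffers):
    candidates[prev_assignments[r]].append(buf)
  out = [None] * len(devices)
  for r, device in enumerate(devices):
    shard_bufs = candidates[new_assignments[r]]  # assumed non-empty
    # Prefer a buffer that already lives on the target device...
    for buf in shard_bufs:
      if buf.device() == device:
        out[r] = buf
        break
    else:
      # ...otherwise copy an arbitrary candidate (here, the last one) over.
      out[r] = shard_bufs[-1].copy_to_device(device)
  return out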
@@ -89,6 +105,7 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
   return buffers
 
 
 shard_arg_handlers = {}
 shard_arg_handlers[core.Unit] = \
     lambda x, devices, _, backend=None: [
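The diff is truncated here, but the call site above shows the contract for entries in `shard_arg_handlers`: `handler(arg, devices, assignments, backend=...)` returning one buffer per device. A hypothetical handler for plain host ndarrays, sketched only to illustrate that contract (not this module's actual registration); `xla.device_put` is the binding already used in the deleted line above:

import numpy as onp

# Hypothetical sketch of a shard_arg_handlers entry: place shard
# assignments[r] of a host ndarray onto devices[r], matching the call-site
# contract shard_arg_handlers[type(arg)](arg, devices, assignments, backend=backend).
def _shard_array(x, devices, assignments, backend=None):
  return [xla.device_put(x[assignments[r]], devices[r], backend=backend)
          for r in range(len(devices))]

# shard_arg_handlers[onp.ndarray] = _shard_array  # illustrative registration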