Batch pxla.shard_args calls triggered by jax.device_put

With this change, one `jax.device_put` call now corresponds to a single `device_put_p.bind()` instead of one bind per array. This immediately improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. It also lets backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) batch device-to-device transfers across the arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
Author: Junwhan Ahn
Date: 2024-06-17 10:16:38 -07:00 (committed by jax authors)
Commit: cec796f5dc (parent: 4ea73bf787)
8 changed files with 147 additions and 78 deletions


@@ -2455,24 +2455,25 @@ def device_put(
blocking the calling Python thread until any transfers are completed.
"""
with config.explicit_device_put_scope():
if ((device is None or
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
(src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
def _map(y):
_check_sharding(shaped_abstractify(y), s=device)
return dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y))
return tree_map(_map, x)
x_flat, treedef = tree_flatten(x)
device_flat = flatten_axes("device_put device", treedef, device)
src_flat = flatten_axes("device_put source", treedef, src)
out_flat = []
for xf, d, s in zip(x_flat, device_flat, src_flat):
if (device is None or
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))):
device_flat = [device] * len(x_flat)
else:
device_flat = flatten_axes("device_put device", treedef, device)
if (src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))):
src_flat = [_infer_src_sharding(src, xf) for xf in x_flat]
else:
src_flat = flatten_axes("device_put source", treedef, src)
src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
for xf, d in zip(x_flat, device_flat):
_check_sharding(shaped_abstractify(xf), d)
out_flat.append(dispatch.device_put_p.bind(
xf, device=d, src=_infer_src_sharding(s, xf)))
out_flat = dispatch.device_put_p.bind(
*x_flat, devices=device_flat, srcs=src_flat
)
return tree_unflatten(treedef, out_flat)
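As a usage-level illustration of the per-leaf branch above (the device choices and tree below are made up, not part of the diff): `device` may itself be a pytree matching `x`, in which case it is flattened with `flatten_axes` and paired leaf by leaf, and the result still flows through the single batched bind.

```python
import jax
import jax.numpy as jnp

devs = jax.local_devices()
tree = {"a": jnp.ones((4,)), "b": jnp.zeros((8,))}

# Per-leaf placement: the targets pytree mirrors the structure of `tree`.
targets = {"a": devs[0], "b": devs[-1]}
out = jax.device_put(tree, targets)

# Each leaf reports the device it was placed on.
print(jax.tree_util.tree_map(lambda y: y.devices(), out))
```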


@@ -18,6 +18,7 @@ from __future__ import annotations
import atexit
from collections.abc import Iterator, Sequence
import contextlib
import dataclasses
from functools import partial
import itertools
import time
@@ -240,10 +241,9 @@ def jaxpr_shardings(
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
elif eqn.primitive is device_put_p:
s = eqn.params['device']
if isinstance(s, Sharding) and s.memory_kind is not None:
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield (s, source_info)
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield from ((s, source_info) for s in eqn.params['devices']
if isinstance(s, Sharding) and s.memory_kind is not None)
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_shardings(subjaxpr)
@@ -322,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
return result_handler(pxla.shard_args([s], [x])[0])
def _override_get_device_assignment(sharding, *args, **kwargs):
da = sharding._device_assignment
return xb.get_device_backend(da[0]), da
@@ -381,6 +377,25 @@ def _mcjax_reshard(x, target_sharding):
pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
@dataclasses.dataclass(frozen=True)
class _DeferredShardArg:
"""Deferred call to `pxla.shard_args`.
Per-array impls return this object instead of a result array to indicate a
deferred `shard_args` call. `_batched_device_put_impl` then batches all
`_DeferredShardArg` objects into a single `shard_args` call.
"""
x: Any
s: Sharding
aval: core.AbstractValue
committed: bool
@property
def result_handler(self):
return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed)
def _device_put_sharding_impl(x, aval, device):
from jax._src import array
@@ -402,7 +417,7 @@ def _device_put_sharding_impl(x, aval, device):
" trying to use device_put in multi-controller JAX which is not"
" supported. Please use jax.make_array_from_single_device_arrays API"
" or pass device or Sharding which represents addressable devices.")
return _put_x(x, s, aval, True)
return _DeferredShardArg(x, s, aval, True)
# Only `Device` exists below. `Sharding` instance is handled above.
if isinstance(x, array.ArrayImpl):
@@ -418,12 +433,15 @@ def _device_put_sharding_impl(x, aval, device):
sh = SingleDeviceSharding(pxla._get_default_device()
if device is None else device)
return _put_x(x, sh, aval, device is not None)
return _DeferredShardArg(x, sh, aval, device is not None)
def _device_put_impl(
x,
device: Device | Sharding | Layout | None = None,
src: Device | Sharding | Layout | None = None):
*,
device: Device | Sharding | Layout | None,
src: Device | Sharding | Layout | None,
):
if (isinstance(device, TransferToMemoryKind) or
isinstance(src, TransferToMemoryKind)):
raise ValueError(
@@ -457,43 +475,95 @@ def _device_put_impl(
return _device_put_sharding_impl(x, aval, device)
def _batched_device_put_impl(
*xs,
devices: Sequence[Device | Sharding | Layout | None],
srcs: Sequence[Device | Sharding | Layout | None],
):
ys = []
shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
for i, (x, device, src) in enumerate(zip(xs, devices, srcs)):
y = _device_put_impl(x, device=device, src=src)
if isinstance(y, _DeferredShardArg):
shard_arg_indices.append(i)
shard_arg_xs.append(y.x)
shard_arg_shardings.append(y.s)
ys.append(y)
if shard_arg_xs:
# Batch shard_arg calls. Helps improve efficiency for backends that support
# efficient batch transfer.
shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)
return ys
device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None, src=None: x)
device_put_p.multiple_results = True
device_put_p.def_impl(_batched_device_put_impl)
device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs)
def device_put_transpose_rule(ct, _, device, src):
return [device_put_p.bind(ct, device=src, src=device)]
ad.deflinear2(device_put_p, device_put_transpose_rule)
batching.defvectorized(device_put_p)
def _device_put_transpose(cts, *_, devices, srcs):
results = [None] * len(cts)
dp_args = []
for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)):
if type(ct) is not ad.Zero:
dp_args.append((i, ct, device, src))
if dp_args:
indices, args, devices, srcs = list(zip(*dp_args))
ys = device_put_p.bind(*args, devices=srcs, srcs=devices)
for i, y in zip(indices, ys):
results[i] = y
return results
ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
ad.primitive_transposes[device_put_p] = _device_put_transpose
def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
if isinstance(device, Sharding):
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
return [x]
return [x]
def _device_put_batcher(batched_args, batch_dims, **params):
mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped]
assert not mapped_batch_dims or all(
mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:]
), batch_dims
return device_put_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[device_put_p] = _device_put_batcher
def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs):
  def lower(x, device, src, aval, out_aval):
    if (isinstance(device, (Sharding, TransferToMemoryKind)) and
        device.memory_kind is not None):
      if isinstance(device, Sharding):
        x = mlir.wrap_with_sharding_op(
            ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
      x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
      return x
    return x
  return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out))
mlir.register_lowering(
device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
mlir.register_lowering(
device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
def _common_device_put_lowering(ctx, x, *, device, src):
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
raise NotImplementedError(
"Passing memory_kind to device_put via Shardings is not supported on"
f" platforms {ctx.module_context.platforms}")
return [x]
def _common_device_put_lowering(ctx, *xs, devices, srcs):
for device in devices:
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
raise NotImplementedError(
"Passing memory_kind to device_put via Shardings is not supported on"
f" platforms {ctx.module_context.platforms}")
return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)
def _propagate_mem_kind_dp(xm, device=None, src=None):
if isinstance(device, (Sharding, TransferToMemoryKind)):
return device.memory_kind
return None
def _propagate_mem_kind_dp(*xm, devices=None, srcs=None):
memory_kinds = []
for device in devices:
if isinstance(device, (Sharding, TransferToMemoryKind)):
memory_kinds.append(device.memory_kind)
else:
memory_kinds.append(None)
return memory_kinds
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
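A standalone sketch of the defer-then-batch pattern that `_DeferredShardArg` and `_batched_device_put_impl` implement above, written with toy types rather than the JAX internals (the names below are invented for illustration):

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class Deferred:
  """A per-item result that still needs one slot of the batched call."""
  payload: Any
  finish: Callable[[Any], Any]  # turns one batched result into a final value


def batched_impl(items: Sequence[Any], per_item, batch_call):
  outs = list(map(per_item, items))  # some entries come back as Deferred
  idxs = [i for i, o in enumerate(outs) if isinstance(o, Deferred)]
  if idxs:
    # One batched call covers every deferred item, mirroring the single
    # pxla.shard_args call in _batched_device_put_impl.
    batched = batch_call([outs[i].payload for i in idxs])
    for i, res in zip(idxs, batched):
      outs[i] = outs[i].finish(res)
  return outs


# Toy usage: strings are handled eagerly, ints are deferred and then "moved"
# (here: doubled) in one batch.
print(batched_impl(
    [1, "skip", 3],
    per_item=lambda x: x if isinstance(x, str) else Deferred(x, lambda r: r),
    batch_call=lambda xs: [x * 2 for x in xs],
))
# -> [2, 'skip', 6]
```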


@@ -654,7 +654,7 @@ def _make_device_put_harness(name,
define(
"device_put",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}",
lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None),
lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0],
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,


@@ -1260,7 +1260,7 @@ def _partial_eval_jaxpr_custom_cached(
outvars_copy = list[Atom](eqn.outvars)
offload_eqn = core.JaxprEqn(
outvars_copy, resvars, device_put_p,
dict(device=TransferToMemoryKind(policy.dst), src=None),
dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
known_eqns.append(offload_eqn)
@@ -1269,7 +1269,7 @@ def _partial_eval_jaxpr_custom_cached(
residuals.update(resvars)
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p, # type: ignore
dict(device=TransferToMemoryKind(policy.src), src=None),
dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
staged_eqns.append(reload_eqn)


@@ -2051,9 +2051,9 @@ def _create_da_object( # pytype: disable=invalid-annotation
def jaxpr_transfer_mem_kinds(
jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
for eqn in jaxpr.eqns:
if (eqn.primitive is dispatch.device_put_p and
isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
yield eqn.params['device']
if eqn.primitive is dispatch.device_put_p:
yield from (d for d in eqn.params['devices']
if isinstance(d, sharding_impls.TransferToMemoryKind))
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_transfer_mem_kinds(subjaxpr)
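A small sketch of what `jaxpr_transfer_mem_kinds` (and `jaxpr_shardings` earlier in this commit) now iterate over: after this change, a traced `device_put` equation carries list-valued `devices`/`srcs` params. The traced function below is illustrative; `jax.make_jaxpr` and `dispatch.device_put_p` are the same APIs used elsewhere in this diff.

```python
import jax
import jax.numpy as jnp
from jax._src import dispatch

# Trace a function containing a device_put and inspect the resulting equation.
closed = jax.make_jaxpr(lambda x: jax.device_put(x, jax.devices()[0]) * 2)(
    jnp.ones(4))
for eqn in closed.jaxpr.eqns:
  if eqn.primitive is dispatch.device_put_p:
    # Both params are now sequences, one entry per array operand.
    print(eqn.params["devices"], eqn.params["srcs"])
```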


@@ -679,12 +679,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
def _maybe_put(x):
if isinstance(x, np.ndarray):
return dispatch._put_x(
x,
jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]),
shaped_abstractify(x),
False,
)
aval = shaped_abstractify(x)
s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0])
result_handler = pxla.global_aval_to_result_handler(aval, s, False)
return result_handler(pxla.shard_args([s], [x])[0])
else:
return x


@@ -1547,7 +1547,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:
tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x
tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs
tf_impl[lax_internal.copy_p] = lambda x: x
def _shard_alike(*args: TfVal, **_):


@@ -896,13 +896,13 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any],
return []
eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule
def _device_put_eager_rule(mesh, x, *, src, device):
del mesh, src
if device is None:
return x
else:
raise ValueError("device_put with explicit device not allowed within "
f"shard_map-decorated functions, but got device {device}")
def _device_put_eager_rule(mesh, *xs, srcs, devices):
del mesh, srcs
for device in devices:
if device is not None:
raise ValueError("device_put with explicit device not allowed within "
f"shard_map-decorated functions, but got device {device}")
return xs
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
# New primitives for efficient transposition
@@ -1145,8 +1145,8 @@ register_norewrite(callback.io_callback_p)
@register_check(dispatch.device_put_p)
def _device_put_rule(mesh, x, **_):
return x
def _device_put_rule(mesh, *xs, **_):
return list(xs)
register_norewrite(dispatch.device_put_p)