mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Batch pxla.shard_args
calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree. The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls. PiperOrigin-RevId: 644051624
This commit is contained in:
parent
4ea73bf787
commit
cec796f5dc
@ -2455,24 +2455,25 @@ def device_put(
|
||||
blocking the calling Python thread until any transfers are completed.
|
||||
"""
|
||||
with config.explicit_device_put_scope():
|
||||
if ((device is None or
|
||||
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
|
||||
(src is None or
|
||||
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
|
||||
def _map(y):
|
||||
_check_sharding(shaped_abstractify(y), s=device)
|
||||
return dispatch.device_put_p.bind(
|
||||
y, device=device, src=_infer_src_sharding(src, y))
|
||||
return tree_map(_map, x)
|
||||
|
||||
x_flat, treedef = tree_flatten(x)
|
||||
device_flat = flatten_axes("device_put device", treedef, device)
|
||||
src_flat = flatten_axes("device_put source", treedef, src)
|
||||
out_flat = []
|
||||
for xf, d, s in zip(x_flat, device_flat, src_flat):
|
||||
if (device is None or
|
||||
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))):
|
||||
device_flat = [device] * len(x_flat)
|
||||
else:
|
||||
device_flat = flatten_axes("device_put device", treedef, device)
|
||||
|
||||
if (src is None or
|
||||
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))):
|
||||
src_flat = [_infer_src_sharding(src, xf) for xf in x_flat]
|
||||
else:
|
||||
src_flat = flatten_axes("device_put source", treedef, src)
|
||||
src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
|
||||
|
||||
for xf, d in zip(x_flat, device_flat):
|
||||
_check_sharding(shaped_abstractify(xf), d)
|
||||
out_flat.append(dispatch.device_put_p.bind(
|
||||
xf, device=d, src=_infer_src_sharding(s, xf)))
|
||||
out_flat = dispatch.device_put_p.bind(
|
||||
*x_flat, devices=device_flat, srcs=src_flat
|
||||
)
|
||||
return tree_unflatten(treedef, out_flat)
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import atexit
|
||||
from collections.abc import Iterator, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import itertools
|
||||
import time
|
||||
@ -240,10 +241,9 @@ def jaxpr_shardings(
|
||||
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
|
||||
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
|
||||
elif eqn.primitive is device_put_p:
|
||||
s = eqn.params['device']
|
||||
if isinstance(s, Sharding) and s.memory_kind is not None:
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield (s, source_info)
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield from ((s, source_info) for s in eqn.params['devices']
|
||||
if isinstance(s, Sharding) and s.memory_kind is not None)
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from jaxpr_shardings(subjaxpr)
|
||||
|
||||
@ -322,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
|
||||
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
||||
|
||||
|
||||
def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
|
||||
result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
|
||||
return result_handler(pxla.shard_args([s], [x])[0])
|
||||
|
||||
def _override_get_device_assignment(sharding, *args, **kwargs):
|
||||
da = sharding._device_assignment
|
||||
return xb.get_device_backend(da[0]), da
|
||||
@ -381,6 +377,25 @@ def _mcjax_reshard(x, target_sharding):
|
||||
pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeferredShardArg:
|
||||
"""Deferred call to `pxla.shard_args`.
|
||||
|
||||
Per-array impls return this object instead of a result array to indicate a
|
||||
deferred `shard_args` call. `_batched_device_put_impl` then batches all
|
||||
`_DeferredShardArg` objects into a single `shard_args` call.
|
||||
"""
|
||||
|
||||
x: Any
|
||||
s: Sharding
|
||||
aval: core.AbstractValue
|
||||
committed: bool
|
||||
|
||||
@property
|
||||
def result_handler(self):
|
||||
return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed)
|
||||
|
||||
|
||||
def _device_put_sharding_impl(x, aval, device):
|
||||
from jax._src import array
|
||||
|
||||
@ -402,7 +417,7 @@ def _device_put_sharding_impl(x, aval, device):
|
||||
" trying to use device_put in multi-controller JAX which is not"
|
||||
" supported. Please use jax.make_array_from_single_device_arrays API"
|
||||
" or pass device or Sharding which represents addressable devices.")
|
||||
return _put_x(x, s, aval, True)
|
||||
return _DeferredShardArg(x, s, aval, True)
|
||||
|
||||
# Only `Device` exists below. `Sharding` instance is handled above.
|
||||
if isinstance(x, array.ArrayImpl):
|
||||
@ -418,12 +433,15 @@ def _device_put_sharding_impl(x, aval, device):
|
||||
|
||||
sh = SingleDeviceSharding(pxla._get_default_device()
|
||||
if device is None else device)
|
||||
return _put_x(x, sh, aval, device is not None)
|
||||
return _DeferredShardArg(x, sh, aval, device is not None)
|
||||
|
||||
|
||||
def _device_put_impl(
|
||||
x,
|
||||
device: Device | Sharding | Layout | None = None,
|
||||
src: Device | Sharding | Layout | None = None):
|
||||
*,
|
||||
device: Device | Sharding | Layout | None,
|
||||
src: Device | Sharding | Layout | None,
|
||||
):
|
||||
if (isinstance(device, TransferToMemoryKind) or
|
||||
isinstance(src, TransferToMemoryKind)):
|
||||
raise ValueError(
|
||||
@ -457,43 +475,95 @@ def _device_put_impl(
|
||||
return _device_put_sharding_impl(x, aval, device)
|
||||
|
||||
|
||||
def _batched_device_put_impl(
|
||||
*xs,
|
||||
devices: Sequence[Device | Sharding | Layout | None],
|
||||
srcs: Sequence[Device | Sharding | Layout | None],
|
||||
):
|
||||
ys = []
|
||||
shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
|
||||
for i, (x, device, src) in enumerate(zip(xs, devices, srcs)):
|
||||
y = _device_put_impl(x, device=device, src=src)
|
||||
if isinstance(y, _DeferredShardArg):
|
||||
shard_arg_indices.append(i)
|
||||
shard_arg_xs.append(y.x)
|
||||
shard_arg_shardings.append(y.s)
|
||||
ys.append(y)
|
||||
|
||||
if shard_arg_xs:
|
||||
# Batch shard_arg calls. Helps improve efficiency for backends that support
|
||||
# efficient batch transfer.
|
||||
shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
|
||||
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
|
||||
assert isinstance(ys[i], _DeferredShardArg)
|
||||
ys[i] = ys[i].result_handler(shard_arg_result)
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
device_put_p = core.Primitive('device_put')
|
||||
device_put_p.def_impl(_device_put_impl)
|
||||
device_put_p.def_abstract_eval(lambda x, device=None, src=None: x)
|
||||
device_put_p.multiple_results = True
|
||||
device_put_p.def_impl(_batched_device_put_impl)
|
||||
device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs)
|
||||
|
||||
def device_put_transpose_rule(ct, _, device, src):
|
||||
return [device_put_p.bind(ct, device=src, src=device)]
|
||||
ad.deflinear2(device_put_p, device_put_transpose_rule)
|
||||
batching.defvectorized(device_put_p)
|
||||
def _device_put_transpose(cts, *_, devices, srcs):
|
||||
results = [None] * len(cts)
|
||||
dp_args = []
|
||||
for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)):
|
||||
if type(ct) is not ad.Zero:
|
||||
dp_args.append((i, ct, device, src))
|
||||
if dp_args:
|
||||
indices, args, devices, srcs = list(zip(*dp_args))
|
||||
ys = device_put_p.bind(*args, devices=srcs, srcs=devices)
|
||||
for i, y in zip(indices, ys):
|
||||
results[i] = y
|
||||
return results
|
||||
ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
|
||||
ad.primitive_transposes[device_put_p] = _device_put_transpose
|
||||
|
||||
def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
|
||||
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
|
||||
device.memory_kind is not None):
|
||||
aval, = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
if isinstance(device, Sharding):
|
||||
x = mlir.wrap_with_sharding_op(
|
||||
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
|
||||
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
|
||||
return [x]
|
||||
return [x]
|
||||
def _device_put_batcher(batched_args, batch_dims, **params):
|
||||
mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped]
|
||||
assert not mapped_batch_dims or all(
|
||||
mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:]
|
||||
), batch_dims
|
||||
return device_put_p.bind(*batched_args, **params), batch_dims
|
||||
batching.primitive_batchers[device_put_p] = _device_put_batcher
|
||||
|
||||
def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs):
|
||||
def lower(x, device, src):
|
||||
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
|
||||
device.memory_kind is not None):
|
||||
aval, = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
if isinstance(device, Sharding):
|
||||
x = mlir.wrap_with_sharding_op(
|
||||
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
|
||||
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
|
||||
return x
|
||||
return x
|
||||
return list(map(lower, xs, devices, srcs))
|
||||
mlir.register_lowering(
|
||||
device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
|
||||
mlir.register_lowering(
|
||||
device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
|
||||
|
||||
|
||||
def _common_device_put_lowering(ctx, x, *, device, src):
|
||||
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
|
||||
device.memory_kind is not None):
|
||||
raise NotImplementedError(
|
||||
"Passing memory_kind to device_put via Shardings is not supported on"
|
||||
f" platforms {ctx.module_context.platforms}")
|
||||
return [x]
|
||||
def _common_device_put_lowering(ctx, *xs, devices, srcs):
|
||||
for device in devices:
|
||||
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
|
||||
device.memory_kind is not None):
|
||||
raise NotImplementedError(
|
||||
"Passing memory_kind to device_put via Shardings is not supported on"
|
||||
f" platforms {ctx.module_context.platforms}")
|
||||
return xs
|
||||
mlir.register_lowering(device_put_p, _common_device_put_lowering)
|
||||
|
||||
def _propagate_mem_kind_dp(xm, device=None, src=None):
|
||||
if isinstance(device, (Sharding, TransferToMemoryKind)):
|
||||
return device.memory_kind
|
||||
return None
|
||||
def _propagate_mem_kind_dp(*xm, devices=None, srcs=None):
|
||||
memory_kinds = []
|
||||
for device in devices:
|
||||
if isinstance(device, (Sharding, TransferToMemoryKind)):
|
||||
memory_kinds.append(device.memory_kind)
|
||||
else:
|
||||
memory_kinds.append(None)
|
||||
return memory_kinds
|
||||
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
|
||||
|
@ -654,7 +654,7 @@ def _make_device_put_harness(name,
|
||||
define(
|
||||
"device_put",
|
||||
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}",
|
||||
lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None),
|
||||
lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0],
|
||||
[RandArg(shape, dtype)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
|
@ -1260,7 +1260,7 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
outvars_copy = list[Atom](eqn.outvars)
|
||||
offload_eqn = core.JaxprEqn(
|
||||
outvars_copy, resvars, device_put_p,
|
||||
dict(device=TransferToMemoryKind(policy.dst), src=None),
|
||||
dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]),
|
||||
set(), source_info_util.new_source_info(),
|
||||
JaxprEqnContext(None, False))
|
||||
known_eqns.append(offload_eqn)
|
||||
@ -1269,7 +1269,7 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
residuals.update(resvars)
|
||||
reload_eqn = core.JaxprEqn(
|
||||
resvars, eqn.outvars, device_put_p, # type: ignore
|
||||
dict(device=TransferToMemoryKind(policy.src), src=None),
|
||||
dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]),
|
||||
set(), source_info_util.new_source_info(),
|
||||
JaxprEqnContext(None, False))
|
||||
staged_eqns.append(reload_eqn)
|
||||
|
@ -2051,9 +2051,9 @@ def _create_da_object( # pytype: disable=invalid-annotation
|
||||
def jaxpr_transfer_mem_kinds(
|
||||
jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
|
||||
for eqn in jaxpr.eqns:
|
||||
if (eqn.primitive is dispatch.device_put_p and
|
||||
isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
|
||||
yield eqn.params['device']
|
||||
if eqn.primitive is dispatch.device_put_p:
|
||||
yield from (d for d in eqn.params['devices']
|
||||
if isinstance(d, sharding_impls.TransferToMemoryKind))
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from jaxpr_transfer_mem_kinds(subjaxpr)
|
||||
|
||||
|
@ -679,12 +679,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
|
||||
def _maybe_put(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return dispatch._put_x(
|
||||
x,
|
||||
jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]),
|
||||
shaped_abstractify(x),
|
||||
False,
|
||||
)
|
||||
aval = shaped_abstractify(x)
|
||||
s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0])
|
||||
result_handler = pxla.global_aval_to_result_handler(aval, s, False)
|
||||
return result_handler(pxla.shard_args([s], [x]))
|
||||
else:
|
||||
return x
|
||||
|
||||
|
@ -1547,7 +1547,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:
|
||||
|
||||
|
||||
tf_impl[ad_util.add_jaxvals_p] = _add
|
||||
tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x
|
||||
tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs
|
||||
tf_impl[lax_internal.copy_p] = lambda x: x
|
||||
|
||||
def _shard_alike(*args: TfVal, **_):
|
||||
|
@ -896,13 +896,13 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any],
|
||||
return []
|
||||
eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule
|
||||
|
||||
def _device_put_eager_rule(mesh, x, *, src, device):
|
||||
del mesh, src
|
||||
if device is None:
|
||||
return x
|
||||
else:
|
||||
raise ValueError("device_put with explicit device not allowed within "
|
||||
f"shard_map-decorated functions, but got device {device}")
|
||||
def _device_put_eager_rule(mesh, *xs, srcs, devices):
|
||||
del mesh, srcs
|
||||
for device in devices:
|
||||
if device is not None:
|
||||
raise ValueError("device_put with explicit device not allowed within "
|
||||
f"shard_map-decorated functions, but got device {device}")
|
||||
return xs
|
||||
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
|
||||
|
||||
# New primitives for efficient transposition
|
||||
@ -1145,8 +1145,8 @@ register_norewrite(callback.io_callback_p)
|
||||
|
||||
|
||||
@register_check(dispatch.device_put_p)
|
||||
def _device_put_rule(mesh, x, **_):
|
||||
return x
|
||||
def _device_put_rule(mesh, *xs, **_):
|
||||
return list(xs)
|
||||
register_norewrite(dispatch.device_put_p)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user