diff --git a/jax/_src/api.py b/jax/_src/api.py
index 161390688..3b1fdd6b3 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2455,24 +2455,25 @@ def device_put(
     blocking the calling Python thread until any transfers are completed.
   """
   with config.explicit_device_put_scope():
-    if ((device is None or
-         isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
-        (src is None or
-         isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
-      def _map(y):
-        _check_sharding(shaped_abstractify(y), s=device)
-        return dispatch.device_put_p.bind(
-            y, device=device, src=_infer_src_sharding(src, y))
-      return tree_map(_map, x)
-    x_flat, treedef = tree_flatten(x)
-    device_flat = flatten_axes("device_put device", treedef, device)
-    src_flat = flatten_axes("device_put source", treedef, src)
-    out_flat = []
-    for xf, d, s in zip(x_flat, device_flat, src_flat):
+    x_flat, treedef = tree_flatten(x)
+    if (device is None or
+        isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))):
+      device_flat = [device] * len(x_flat)
+    else:
+      device_flat = flatten_axes("device_put device", treedef, device)
+
+    if (src is None or
+        isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))):
+      src_flat = [_infer_src_sharding(src, xf) for xf in x_flat]
+    else:
+      src_flat = flatten_axes("device_put source", treedef, src)
+      src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
+
+    for xf, d in zip(x_flat, device_flat):
       _check_sharding(shaped_abstractify(xf), d)
-      out_flat.append(dispatch.device_put_p.bind(
-          xf, device=d, src=_infer_src_sharding(s, xf)))
+    out_flat = dispatch.device_put_p.bind(
+        *x_flat, devices=device_flat, srcs=src_flat)
     return tree_unflatten(treedef, out_flat)
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 9f2fae105..d21963d70 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 import atexit
 from collections.abc import Iterator, Sequence
 import contextlib
+import dataclasses
 from functools import partial
 import itertools
 import time
@@ -240,10 +241,9 @@ def jaxpr_shardings(
       yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                   for names in [*eqn.params['in_names'], *eqn.params['out_names']])
     elif eqn.primitive is device_put_p:
-      s = eqn.params['device']
-      if isinstance(s, Sharding) and s.memory_kind is not None:
-        source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
-        yield (s, source_info)
+      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
+      yield from ((s, source_info) for s in eqn.params['devices']
+                  if isinstance(s, Sharding) and s.memory_kind is not None)
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)
@@ -322,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")
 
 
-def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
-  result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
-  return result_handler(pxla.shard_args([s], [x])[0])
-
 def _override_get_device_assignment(sharding, *args, **kwargs):
   da = sharding._device_assignment
   return xb.get_device_backend(da[0]), da
@@ -381,6 +377,25 @@ def _mcjax_reshard(x, target_sharding):
     pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
 
 
+@dataclasses.dataclass(frozen=True)
+class _DeferredShardArg:
+  """Deferred call to `pxla.shard_args`.
+
+  Per-array impls return this object instead of a result array to indicate a
+  deferred `shard_args` call. `_batched_device_put_impl` then batches all
+  `_DeferredShardArg` objects into a single `shard_args` call.
+  """
+
+  x: Any
+  s: Sharding
+  aval: core.AbstractValue
+  committed: bool
+
+  @property
+  def result_handler(self):
+    return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed)
+
+
 def _device_put_sharding_impl(x, aval, device):
   from jax._src import array
 
@@ -402,7 +417,7 @@ def _device_put_sharding_impl(x, aval, device):
           " trying to use device_put in multi-controller JAX which is not"
           " supported. Please use jax.make_array_from_single_device_arrays API"
           " or pass device or Sharding which represents addressable devices.")
-    return _put_x(x, s, aval, True)
+    return _DeferredShardArg(x, s, aval, True)
 
   # Only `Device` exists below. `Sharding` instance is handled above.
   if isinstance(x, array.ArrayImpl):
@@ -418,12 +433,15 @@ def _device_put_sharding_impl(x, aval, device):
 
   sh = SingleDeviceSharding(pxla._get_default_device()
                             if device is None else device)
-  return _put_x(x, sh, aval, device is not None)
+  return _DeferredShardArg(x, sh, aval, device is not None)
+
 
 
 def _device_put_impl(
     x,
-    device: Device | Sharding | Layout | None = None,
-    src: Device | Sharding | Layout | None = None):
+    *,
+    device: Device | Sharding | Layout | None,
+    src: Device | Sharding | Layout | None,
+):
   if (isinstance(device, TransferToMemoryKind) or
       isinstance(src, TransferToMemoryKind)):
     raise ValueError(
@@ -457,43 +475,95 @@ def _device_put_impl(
   return _device_put_sharding_impl(x, aval, device)
 
 
+def _batched_device_put_impl(
+    *xs,
+    devices: Sequence[Device | Sharding | Layout | None],
+    srcs: Sequence[Device | Sharding | Layout | None],
+):
+  ys = []
+  shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
+  for i, (x, device, src) in enumerate(zip(xs, devices, srcs)):
+    y = _device_put_impl(x, device=device, src=src)
+    if isinstance(y, _DeferredShardArg):
+      shard_arg_indices.append(i)
+      shard_arg_xs.append(y.x)
+      shard_arg_shardings.append(y.s)
+    ys.append(y)
+
+  if shard_arg_xs:
+    # Batch shard_arg calls. Helps improve efficiency for backends that support
+    # efficient batch transfer.
+    shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
+    for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
+      assert isinstance(ys[i], _DeferredShardArg)
+      ys[i] = ys[i].result_handler(shard_arg_result)
+
+  return ys
+
+
 device_put_p = core.Primitive('device_put')
-device_put_p.def_impl(_device_put_impl)
-device_put_p.def_abstract_eval(lambda x, device=None, src=None: x)
+device_put_p.multiple_results = True
+device_put_p.def_impl(_batched_device_put_impl)
+device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs)
 
-def device_put_transpose_rule(ct, _, device, src):
-  return [device_put_p.bind(ct, device=src, src=device)]
-ad.deflinear2(device_put_p, device_put_transpose_rule)
-batching.defvectorized(device_put_p)
+def _device_put_transpose(cts, *_, devices, srcs):
+  results = [None] * len(cts)
+  dp_args = []
+  for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)):
+    if type(ct) is not ad.Zero:
+      dp_args.append((i, ct, device, src))
+  if dp_args:
+    indices, args, devices, srcs = list(zip(*dp_args))
+    ys = device_put_p.bind(*args, devices=srcs, srcs=devices)
+    for i, y in zip(indices, ys):
+      results[i] = y
+  return results
+ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
+ad.primitive_transposes[device_put_p] = _device_put_transpose
 
-def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
-      device.memory_kind is not None):
-    aval, = ctx.avals_in
-    out_aval, = ctx.avals_out
-    if isinstance(device, Sharding):
-      x = mlir.wrap_with_sharding_op(
-          ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
-    x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
-    return [x]
-  return [x]
+def _device_put_batcher(batched_args, batch_dims, **params):
+  mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped]
+  assert not mapped_batch_dims or all(
+      mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:]
+  ), batch_dims
+  return device_put_p.bind(*batched_args, **params), batch_dims
+batching.primitive_batchers[device_put_p] = _device_put_batcher
+
+def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs):
+  def lower(x, device, src):
+    if (isinstance(device, (Sharding, TransferToMemoryKind)) and
+        device.memory_kind is not None):
+      aval, = ctx.avals_in
+      out_aval, = ctx.avals_out
+      if isinstance(device, Sharding):
+        x = mlir.wrap_with_sharding_op(
+            ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+      x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
+      return x
+    return x
+  return list(map(lower, xs, devices, srcs))
 mlir.register_lowering(
     device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
 mlir.register_lowering(
     device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
 
-def _common_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
-      device.memory_kind is not None):
-    raise NotImplementedError(
-        "Passing memory_kind to device_put via Shardings is not supported on"
-        f" platforms {ctx.module_context.platforms}")
-  return [x]
+def _common_device_put_lowering(ctx, *xs, devices, srcs):
+  for device in devices:
+    if (isinstance(device, (Sharding, TransferToMemoryKind)) and
+        device.memory_kind is not None):
+      raise NotImplementedError(
+          "Passing memory_kind to device_put via Shardings is not supported on"
+          f" platforms {ctx.module_context.platforms}")
+  return xs
 mlir.register_lowering(device_put_p,
                        _common_device_put_lowering)
 
-def _propagate_mem_kind_dp(xm, device=None, src=None):
-  if isinstance(device, (Sharding, TransferToMemoryKind)):
-    return device.memory_kind
-  return None
+def _propagate_mem_kind_dp(*xm, devices=None, srcs=None):
+  memory_kinds = []
+  for device in devices:
+    if isinstance(device, (Sharding, TransferToMemoryKind)):
+      memory_kinds.append(device.memory_kind)
+    else:
+      memory_kinds.append(None)
+  return memory_kinds
 pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index c8b8d203a..44b678d7e 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -654,7 +654,7 @@ def _make_device_put_harness(name,
   define(
       "device_put",
       f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}",
-      lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None),
+      lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0],
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index d3f17328e..830ec1f9c 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1260,7 +1260,7 @@ def _partial_eval_jaxpr_custom_cached(
       outvars_copy = list[Atom](eqn.outvars)
       offload_eqn = core.JaxprEqn(
           outvars_copy, resvars, device_put_p,
-          dict(device=TransferToMemoryKind(policy.dst), src=None),
+          dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]),
           set(), source_info_util.new_source_info(),
           JaxprEqnContext(None, False))
       known_eqns.append(offload_eqn)
@@ -1269,7 +1269,7 @@ def _partial_eval_jaxpr_custom_cached(
       residuals.update(resvars)
       reload_eqn = core.JaxprEqn(
           resvars, eqn.outvars, device_put_p,  # type: ignore
-          dict(device=TransferToMemoryKind(policy.src), src=None),
+          dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]),
           set(), source_info_util.new_source_info(),
           JaxprEqnContext(None, False))
       staged_eqns.append(reload_eqn)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 7e9d15c7c..d795abcde 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2051,9 +2051,9 @@ def _create_da_object(  # pytype: disable=invalid-annotation
 
 def jaxpr_transfer_mem_kinds(
     jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
   for eqn in jaxpr.eqns:
-    if (eqn.primitive is dispatch.device_put_p and
-        isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
-      yield eqn.params['device']
+    if eqn.primitive is dispatch.device_put_p:
+      yield from (d for d in eqn.params['devices']
+                  if isinstance(d, sharding_impls.TransferToMemoryKind))
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_transfer_mem_kinds(subjaxpr)
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 8e372d8ce..3bfb60c6e 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -679,12 +679,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
 
 def _maybe_put(x):
   if isinstance(x, np.ndarray):
-    return dispatch._put_x(
-        x,
-        jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]),
-        shaped_abstractify(x),
-        False,
-    )
+    aval = shaped_abstractify(x)
+    s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0])
+    result_handler = pxla.global_aval_to_result_handler(aval, s, False)
+    return result_handler(pxla.shard_args([s], [x])[0])
   else:
     return x
 
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 27073bcae..b53ffda9c 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1547,7 +1547,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:
 
 
 tf_impl[ad_util.add_jaxvals_p] = _add
-tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x
+tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs
 tf_impl[lax_internal.copy_p] = lambda x: x
 
 def _shard_alike(*args: TfVal, **_):
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 2fabdcf38..76cfb266a 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -896,13 +896,13 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any],
   return []
 eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule
 
-def _device_put_eager_rule(mesh, x, *, src, device):
-  del mesh, src
-  if device is None:
-    return x
-  else:
-    raise ValueError("device_put with explicit device not allowed within "
-                     f"shard_map-decorated functions, but got device {device}")
+def _device_put_eager_rule(mesh, *xs, srcs, devices):
+  del mesh, srcs
+  for device in devices:
+    if device is not None:
+      raise ValueError("device_put with explicit device not allowed within "
+                       f"shard_map-decorated functions, but got device {device}")
+  return xs
 eager_rules[dispatch.device_put_p] = _device_put_eager_rule
 
 # New primitives for efficient transposition
@@ -1145,8 +1145,8 @@ register_norewrite(callback.io_callback_p)
 
 
 @register_check(dispatch.device_put_p)
-def _device_put_rule(mesh, x, **_):
-  return x
+def _device_put_rule(mesh, *xs, **_):
+  return list(xs)
 register_norewrite(dispatch.device_put_p)
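
Reviewer sketch (not part of the patch): with this change, a pytree-level jax.device_put flattens its
arguments and issues one device_put_p.bind(*leaves, devices=..., srcs=...), and the per-leaf
_DeferredShardArg results are grouped into a single pxla.shard_args call. A minimal way to exercise
the new path through the public API, assuming a default CPU backend is available:

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    cpu0 = jax.devices("cpu")[0]
    tree = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

    # Both leaves are flattened into one batched device_put_p.bind call; the
    # result is a pytree with the same structure, committed to cpu0.
    out = jax.device_put(tree, SingleDeviceSharding(cpu0))
    print(jax.tree_util.tree_map(lambda a: a.sharding, out))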