From aeb62cc00638675accebe8f62f30f1512bf17cad Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 22 Aug 2023 22:07:24 -0700 Subject: [PATCH] Add `TransferToMemoryKind` as a private API to allow device_put to transfer to different memories without specifying the sharding and allowing the SPMD partitioner to choose the sharding for the intermediate. Exposing it as a public API can be done later. PiperOrigin-RevId: 559314369 --- jax/_src/api.py | 16 +++++++++------- jax/_src/dispatch.py | 18 ++++++++++++++---- jax/_src/interpreters/pxla.py | 14 +++++++++++++- jax/_src/sharding_impls.py | 5 +++++ 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 32648decc..69ea33572 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -66,7 +66,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.sharding_impls import PmapSharding +from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps @@ -2489,8 +2489,8 @@ def _infer_src_sharding(src, x): def device_put( x, - device: None | xc.Device | Sharding | Any = None, - *, src: None | xc.Device | Sharding | Any = None): + device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None): """Transfers ``x`` to ``device``. Args: @@ -2514,8 +2514,10 @@ def device_put( blocking the calling Python thread until any transfers are completed. """ with config_explicit_device_put_scope(): - if ((device is None or isinstance(device, (xc.Device, Sharding))) and - (src is None or isinstance(src, (xc.Device, Sharding)))): + if ((device is None or + isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and + (src is None or + isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))): return tree_map( lambda y: dispatch.device_put_p.bind( y, device=device, src=_infer_src_sharding(src, y)), x) @@ -2524,8 +2526,8 @@ def device_put( device_flat = flatten_axes("device_put device", treedef, device) src_flat = flatten_axes("device_put source", treedef, src) out_flat = [ - dispatch.device_put_p.bind(y, device=d, src=_infer_src_sharding(s, y)) - for y, d, s in zip(x_flat, device_flat, src_flat) + dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf)) + for xf, d, s in zip(x_flat, device_flat, src_flat) ] return tree_unflatten(treedef, out_flat) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index a8c6d2e29..46fab6cd0 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -49,7 +49,7 @@ from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, - UNSPECIFIED, GSPMDSharding) + UNSPECIFIED, GSPMDSharding, TransferToMemoryKind) JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -467,6 +467,14 @@ def _device_put_impl( device: Device | Sharding | None = None, src: Device | Sharding | None = None): from jax._src import array + + if (isinstance(device, TransferToMemoryKind) or + isinstance(src, TransferToMemoryKind)): + raise ValueError( + "TransferToMemoryKind argument to jax.device_put can only be used" + " inside jax.jit. If you are using device_put outside jax.jit, then" + " please provide a concrete Sharding with memory_kind.") + try: aval = xla.abstractify(x) except TypeError as err: @@ -521,12 +529,14 @@ ad.deflinear2(device_put_p, device_put_transpose_rule) batching.defvectorized(device_put_p) def _device_put_lowering(ctx, x, *, device, src): - if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None: + if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and + device.memory_kind is not None): aval, = ctx.avals_in out_aval, = ctx.avals_out x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) - x = mlir.wrap_with_sharding_op( - ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) + if isinstance(device, XLACompatibleSharding): + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) return [x] return [x] mlir.register_lowering(device_put_p, _device_put_lowering) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4ae94e671..e5136a366 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1925,6 +1925,17 @@ def _create_da_object( # pytype: disable=invalid-annotation return _DeviceAssignment(device_assignment) +def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool: + for eqn in jaxpr.eqns: + if (eqn.primitive is dispatch.device_put_p and + isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)): + return True + for subjaxpr in core.subjaxprs(jaxpr): + if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr): + return True + return False + + @profiler.annotate_function def lower_sharding_computation( fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr, @@ -1983,7 +1994,8 @@ def lower_sharding_computation( len(device_assignment) > 1 or any(not is_unspecified(i) for i in in_shardings) or any(not is_unspecified(js) for js, _ in jaxpr_sharding) or - any(not is_unspecified(o) for o in out_shardings)) + any(not is_unspecified(o) for o in out_shardings) or + jaxpr_has_dp_with_transfer_mem_kind(jaxpr)) gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment) in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index cbc5ffb76..efa1c5896 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -58,6 +58,11 @@ _ENABLE_MEMORY_KIND = jax_config.DEFINE_bool( "and annotate Shardings with it.")) +@dataclasses.dataclass(frozen=True) +class TransferToMemoryKind: + memory_kind: str + + # Shardings that inherit from XLACompatibleSharding should implement the # `_device_assignment` property and `_to_xla_hlo_sharding` method. @use_cpp_class(xc.XLACompatibleSharding)