Add TransferToMemoryKind as a private API so that device_put can transfer to different memories without specifying a sharding, letting the SPMD partitioner choose the sharding for the intermediate.

Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
Yash Katariya 2023-08-22 22:07:24 -07:00 committed by jax authors
parent bad217b2f8
commit aeb62cc006
4 changed files with 41 additions and 12 deletions
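For orientation, here is a minimal usage sketch of what this change enables. It assumes memories are enabled and that the backend exposes a memory kind named "unpinned_host" (both assumptions); TransferToMemoryKind is still a private import at this revision.

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # private API at this revision

@jax.jit
def f(x):
  y = x * 2
  # Only the destination memory kind is given; the SPMD partitioner is free
  # to choose the sharding of the transferred intermediate.
  return jax.device_put(y, TransferToMemoryKind("unpinned_host"))

out = f(jnp.arange(8.0))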

jax/_src/api.py

@@ -66,7 +66,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding
+from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
 from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@@ -2489,8 +2489,8 @@ def _infer_src_sharding(src, x):
 
 def device_put(
     x,
-    device: None | xc.Device | Sharding | Any = None,
-    *, src: None | xc.Device | Sharding | Any = None):
+    device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
+    *, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
   """Transfers ``x`` to ``device``.
 
   Args:
@@ -2514,8 +2514,10 @@ def device_put(
     blocking the calling Python thread until any transfers are completed.
   """
   with config_explicit_device_put_scope():
-    if ((device is None or isinstance(device, (xc.Device, Sharding))) and
-        (src is None or isinstance(src, (xc.Device, Sharding)))):
+    if ((device is None or
+         isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
+        (src is None or
+         isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
       return tree_map(
           lambda y: dispatch.device_put_p.bind(
               y, device=device, src=_infer_src_sharding(src, y)), x)
@@ -2524,8 +2526,8 @@ def device_put(
     device_flat = flatten_axes("device_put device", treedef, device)
     src_flat = flatten_axes("device_put source", treedef, src)
     out_flat = [
-        dispatch.device_put_p.bind(y, device=d, src=_infer_src_sharding(s, y))
-        for y, d, s in zip(x_flat, device_flat, src_flat)
+        dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
+        for xf, d, s in zip(x_flat, device_flat, src_flat)
     ]
     return tree_unflatten(treedef, out_flat)
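Because the pytree branch above flattens `device` against `x`, each leaf can get its own destination. A hedged sketch inside jit; the memory-kind names "unpinned_host" and "device" are assumptions about the backend:

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind

@jax.jit
def g(pair):
  a, b = pair
  # `device` mirrors the pytree structure of `x`, so destinations are per leaf.
  return jax.device_put((a, b), (TransferToMemoryKind("unpinned_host"),
                                 TransferToMemoryKind("device")))

g((jnp.ones(4), jnp.ones(4)))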

jax/_src/dispatch.py

@@ -49,7 +49,7 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
-    UNSPECIFIED, GSPMDSharding)
+    UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)
 
 
 JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -467,6 +467,14 @@ def _device_put_impl(
     device: Device | Sharding | None = None,
     src: Device | Sharding | None = None):
   from jax._src import array
+
+  if (isinstance(device, TransferToMemoryKind) or
+      isinstance(src, TransferToMemoryKind)):
+    raise ValueError(
+        "TransferToMemoryKind argument to jax.device_put can only be used"
+        " inside jax.jit. If you are using device_put outside jax.jit, then"
+        " please provide a concrete Sharding with memory_kind.")
+
   try:
     aval = xla.abstractify(x)
   except TypeError as err:
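To illustrate the eager-mode guard above: outside jit a bare memory kind is rejected, and the error points at a concrete Sharding instead. A sketch; the memory-kind name and the memory_kind argument of SingleDeviceSharding are assumptions about this revision:

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind

x = jnp.arange(4.0)
try:
  jax.device_put(x, TransferToMemoryKind("unpinned_host"))  # eager: hits _device_put_impl
except ValueError as e:
  print(e)

# Eager alternative suggested by the error: a concrete Sharding carrying the memory kind.
s = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
y = jax.device_put(x, s)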
@@ -521,12 +529,14 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)
 
 def _device_put_lowering(ctx, x, *, device, src):
-  if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None:
+  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+      device.memory_kind is not None):
     aval, = ctx.avals_in
     out_aval, = ctx.avals_out
     x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
-    x = mlir.wrap_with_sharding_op(
-        ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+    if isinstance(device, XLACompatibleSharding):
+      x = mlir.wrap_with_sharding_op(
+          ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
 mlir.register_lowering(device_put_p, _device_put_lowering)
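One way to observe this lowering rule is to inspect the lowered module text. A sketch; the exact annotations emitted are backend- and version-dependent, and the memory-kind name is assumed:

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind

def f(x):
  return jax.device_put(x + 1, TransferToMemoryKind("unpinned_host"))

# A TransferToMemoryKind destination gets only the memory-kind annotation;
# a concrete XLACompatibleSharding destination would also get a sharding op.
print(jax.jit(f).lower(jnp.ones(8)).as_text())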

jax/_src/interpreters/pxla.py

@@ -1925,6 +1925,17 @@ def _create_da_object( # pytype: disable=invalid-annotation
   return _DeviceAssignment(device_assignment)
 
 
+def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
+  for eqn in jaxpr.eqns:
+    if (eqn.primitive is dispatch.device_put_p and
+        isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
+      return True
+  for subjaxpr in core.subjaxprs(jaxpr):
+    if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
+      return True
+  return False
+
+
 @profiler.annotate_function
 def lower_sharding_computation(
     fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
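The jaxpr_has_dp_with_transfer_mem_kind check above can be reproduced outside pxla for inspection. A sketch using jax.make_jaxpr; the standalone helper name has_transfer_device_put is made up and the memory-kind name is assumed:

import jax
import jax.numpy as jnp
from jax._src import core, dispatch
from jax._src.sharding_impls import TransferToMemoryKind

def has_transfer_device_put(jaxpr: core.Jaxpr) -> bool:
  # Mirrors jaxpr_has_dp_with_transfer_mem_kind: scan the eqns, then recurse.
  for eqn in jaxpr.eqns:
    if (eqn.primitive is dispatch.device_put_p and
        isinstance(eqn.params['device'], TransferToMemoryKind)):
      return True
  return any(has_transfer_device_put(sub) for sub in core.subjaxprs(jaxpr))

def f(x):
  return jax.device_put(x * 2, TransferToMemoryKind("unpinned_host"))

closed = jax.make_jaxpr(f)(jnp.ones(4))
print(has_transfer_device_put(closed.jaxpr))  # expected: True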
@@ -1983,7 +1994,8 @@ def lower_sharding_computation(
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
-      any(not is_unspecified(o) for o in out_shardings))
+      any(not is_unspecified(o) for o in out_shardings) or
+      jaxpr_has_dp_with_transfer_mem_kind(jaxpr))
 
   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

jax/_src/sharding_impls.py

@@ -58,6 +58,11 @@ _ENABLE_MEMORY_KIND = jax_config.DEFINE_bool(
     "and annotate Shardings with it."))
 
 
+@dataclasses.dataclass(frozen=True)
+class TransferToMemoryKind:
+  memory_kind: str
+
+
 # Shardings that inherit from XLACompatibleSharding should implement the
 # `_device_assignment` property and `_to_xla_hlo_sharding` method.
 @use_cpp_class(xc.XLACompatibleSharding)
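For reference, the new class is just a frozen value wrapper around the memory-kind string, and its use is gated by the _ENABLE_MEMORY_KIND flag above. A small sketch; the memory-kind names are assumptions:

from jax._src.sharding_impls import TransferToMemoryKind

dst = TransferToMemoryKind("unpinned_host")
print(dst.memory_kind)                               # "unpinned_host"
print(dst == TransferToMemoryKind("unpinned_host"))  # True: frozen dataclass equality
# dst.memory_kind = "pinned_host"  # would raise dataclasses.FrozenInstanceError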