Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Add TransferToMemoryKind

Added as a private API so that device_put can transfer a value to a different memory without specifying a sharding, leaving the SPMD partitioner free to choose the sharding for the intermediate. Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
This commit is contained in:
parent bad217b2f8
commit aeb62cc006
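As a rough sketch of the intended use described above (not part of the commit): inside jax.jit, device_put can name only the destination memory kind and leave the sharding choice to the SPMD partitioner. The memory-kind string "unpinned_host" below is a hypothetical, platform-dependent name, and the private import path reflects that the API is not public at this revision.

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # private API

@jax.jit
def f(x):
  y = x * 2
  # Request a transfer of the intermediate to host memory without fixing a
  # sharding; the SPMD partitioner picks the sharding for the result.
  return jax.device_put(y, TransferToMemoryKind("unpinned_host"))

out = f(jnp.arange(8.0))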
@@ -66,7 +66,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding
+from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
 from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@@ -2489,8 +2489,8 @@ def _infer_src_sharding(src, x):
 
 def device_put(
     x,
-    device: None | xc.Device | Sharding | Any = None,
-    *, src: None | xc.Device | Sharding | Any = None):
+    device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
+    *, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
   """Transfers ``x`` to ``device``.
 
   Args:
@@ -2514,8 +2514,10 @@ def device_put(
     blocking the calling Python thread until any transfers are completed.
   """
   with config_explicit_device_put_scope():
-    if ((device is None or isinstance(device, (xc.Device, Sharding))) and
-        (src is None or isinstance(src, (xc.Device, Sharding)))):
+    if ((device is None or
+         isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
+        (src is None or
+         isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
       return tree_map(
           lambda y: dispatch.device_put_p.bind(
               y, device=device, src=_infer_src_sharding(src, y)), x)
@@ -2524,8 +2526,8 @@ def device_put(
     device_flat = flatten_axes("device_put device", treedef, device)
     src_flat = flatten_axes("device_put source", treedef, src)
     out_flat = [
-        dispatch.device_put_p.bind(y, device=d, src=_infer_src_sharding(s, y))
-        for y, d, s in zip(x_flat, device_flat, src_flat)
+        dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
+        for xf, d, s in zip(x_flat, device_flat, src_flat)
     ]
     return tree_unflatten(treedef, out_flat)
@@ -49,7 +49,7 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
-    UNSPECIFIED, GSPMDSharding)
+    UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)
 
 
 JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -467,6 +467,14 @@ def _device_put_impl(
     device: Device | Sharding | None = None,
     src: Device | Sharding | None = None):
   from jax._src import array
 
+  if (isinstance(device, TransferToMemoryKind) or
+      isinstance(src, TransferToMemoryKind)):
+    raise ValueError(
+        "TransferToMemoryKind argument to jax.device_put can only be used"
+        " inside jax.jit. If you are using device_put outside jax.jit, then"
+        " please provide a concrete Sharding with memory_kind.")
+
   try:
     aval = xla.abstractify(x)
   except TypeError as err:
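Outside jit the hunk above raises, and the error text names the supported eager path: a concrete Sharding that carries a memory_kind. A minimal sketch of that alternative, assuming SingleDeviceSharding accepts a memory_kind keyword at this revision and again using a hypothetical memory-kind name:

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
# A concrete sharding pins both the device and the destination memory,
# so no partitioner decision is needed in eager mode.
# memory_kind kwarg and "unpinned_host" name are assumptions here.
host_sharding = SingleDeviceSharding(dev, memory_kind="unpinned_host")
y = jax.device_put(jnp.arange(8.0), host_sharding)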
@@ -521,12 +529,14 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)
 
 def _device_put_lowering(ctx, x, *, device, src):
-  if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None:
+  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+      device.memory_kind is not None):
     aval, = ctx.avals_in
     out_aval, = ctx.avals_out
     x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
-    x = mlir.wrap_with_sharding_op(
-        ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+    if isinstance(device, XLACompatibleSharding):
+      x = mlir.wrap_with_sharding_op(
+          ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
 mlir.register_lowering(device_put_p, _device_put_lowering)
@@ -1925,6 +1925,17 @@ def _create_da_object(  # pytype: disable=invalid-annotation
   return _DeviceAssignment(device_assignment)
 
 
+def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
+  for eqn in jaxpr.eqns:
+    if (eqn.primitive is dispatch.device_put_p and
+        isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
+      return True
+  for subjaxpr in core.subjaxprs(jaxpr):
+    if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
+      return True
+  return False
+
+
 @profiler.annotate_function
 def lower_sharding_computation(
     fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
@@ -1983,7 +1994,8 @@ def lower_sharding_computation(
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
-      any(not is_unspecified(o) for o in out_shardings))
+      any(not is_unspecified(o) for o in out_shardings) or
+      jaxpr_has_dp_with_transfer_mem_kind(jaxpr))
 
   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
@@ -58,6 +58,11 @@ _ENABLE_MEMORY_KIND = jax_config.DEFINE_bool(
     "and annotate Shardings with it."))
 
 
+@dataclasses.dataclass(frozen=True)
+class TransferToMemoryKind:
+  memory_kind: str
+
+
 # Shardings that inherit from XLACompatibleSharding should implement the
 # `_device_assignment` property and `_to_xla_hlo_sharding` method.
 @use_cpp_class(xc.XLACompatibleSharding)
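The new class itself is deliberately minimal: a frozen dataclass wrapping only the target memory-kind string, which makes instances immutable, hashable, and comparable by value. A quick self-contained illustration of those semantics (the class is redefined locally; the memory-kind name is hypothetical):

import dataclasses

@dataclasses.dataclass(frozen=True)
class TransferToMemoryKind:
  memory_kind: str

t = TransferToMemoryKind("unpinned_host")
assert t == TransferToMemoryKind("unpinned_host")  # value equality from eq=True
assert hash(t) == hash(TransferToMemoryKind("unpinned_host"))  # hashable via frozen=True
try:
  t.memory_kind = "device"  # frozen=True forbids mutation
except dataclasses.FrozenInstanceError:
  pass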