mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12675 from mattjj:device-put2
PiperOrigin-RevId: 479660808
This commit is contained in:
commit
58cd8376ee
@ -19,6 +19,7 @@ arguments and outputs. The Python containers handled are pytrees (see
|
||||
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
|
||||
arrays.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import functools
|
||||
@ -2794,13 +2795,18 @@ def make_jaxpr(fun: Callable,
|
||||
return make_jaxpr_f
|
||||
|
||||
|
||||
def device_put(x, device: Optional[xc.Device] = None):
|
||||
def device_put(
|
||||
x, device: Optional[Union[xc.Device, jax.sharding.Sharding]] = None):
|
||||
"""Transfers ``x`` to ``device``.
|
||||
|
||||
Args:
|
||||
x: An array, scalar, or (nested) standard Python container thereof.
|
||||
device: The (optional) :py:class:`Device` to which ``x`` should be
|
||||
transferred. If given, then the result is committed to the device.
|
||||
device: The (optional) :py:class:`Device` or `Sharding` representing the
|
||||
device(s) to which ``x`` should be transferred. If given, then the result
|
||||
is committed to the device(s).
|
||||
|
||||
Returns:
|
||||
A copy of ``x`` that resides on ``device``.
|
||||
|
||||
If the ``device`` parameter is ``None``, then this operation behaves like the
|
||||
identity function if the operand is on any device already, otherwise it
|
||||
@ -2809,10 +2815,8 @@ def device_put(x, device: Optional[xc.Device] = None):
|
||||
For more details on data placement see the
|
||||
:ref:`FAQ on data placement <faq-data-placement>`.
|
||||
|
||||
This function is always asynchronous, i.e. returns immediately.
|
||||
|
||||
Returns:
|
||||
A copy of ``x`` that resides on ``device``.
|
||||
This function is always asynchronous, i.e. returns immediately without
|
||||
blocking the calling Python thread until any transfers are completed.
|
||||
"""
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
|
||||
|
@ -1292,15 +1292,34 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
|
||||
committed=(device is not None))
|
||||
|
||||
|
||||
def _device_put_impl(x, device: Optional[Device] = None):
|
||||
def _device_put_impl(
|
||||
x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
|
||||
from jax._src import array, sharding
|
||||
|
||||
if isinstance(device, sharding.Sharding):
|
||||
if not device.is_fully_addressable(): # type: ignore
|
||||
raise ValueError(
|
||||
"device_put's second argument must be a Device or a Sharding which "
|
||||
f"represents addressable devices, but got {sharding}")
|
||||
if getattr(x, 'sharding', None) == device:
|
||||
return x
|
||||
# TODO(mattjj,yashkatariya,phawkins): runtime fast resharding here?
|
||||
return array.make_array_from_callback(x.shape, device, lambda idx: x[idx])
|
||||
|
||||
# Only `Device` exists below. `Sharding` instance is handled above.
|
||||
if isinstance(x, array.ArrayImpl):
|
||||
if not x.is_fully_addressable():
|
||||
raise ValueError(
|
||||
"device_put's first argument must be a fully addressable array, but "
|
||||
f"got value with devices {x.devices()}")
|
||||
if device is None:
|
||||
return x
|
||||
elif is_single_device_sharding(x.sharding):
|
||||
return _copy_array_to_device(x, device)
|
||||
|
||||
if device_array.type_is_device_array(x):
|
||||
return _copy_device_array_to_device(x, device)
|
||||
|
||||
if type(x) is array.ArrayImpl and isinstance(x.sharding, sharding.SingleDeviceSharding):
|
||||
return _copy_array_to_device(x, device)
|
||||
|
||||
try:
|
||||
a = xla.abstractify(x)
|
||||
except TypeError as err:
|
||||
|
@ -1928,7 +1928,7 @@ def _convert_to_array_if_dtype_fails(x):
|
||||
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "asarray")
|
||||
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
|
||||
return array(a, dtype=dtype, copy=False, order=order)
|
||||
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
|
||||
|
||||
|
||||
@_wraps(np.copy, lax_description=_ARRAY_DOC)
|
||||
|
@ -3462,7 +3462,7 @@ def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
|
||||
if xla_extension_version >= 82:
|
||||
if len(op.tile_assignment_devices) == 1:
|
||||
return True
|
||||
return xc.HloSharding.from_proto(op).is_replicated()
|
||||
return xc.HloSharding.from_proto(op).is_replicated() # type: ignore
|
||||
else:
|
||||
return op.type == xc.OpSharding.Type.REPLICATED
|
||||
|
||||
|
@ -56,6 +56,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax._src import array, sharding
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import maps
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import device_array
|
||||
@ -1486,6 +1487,32 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(y2[1][1], np.ndarray)
|
||||
assert np.all(y2[1][1] == 3 * x)
|
||||
|
||||
def test_device_put_sharding(self):
|
||||
mesh = maps.Mesh(jax.devices(), ('x',))
|
||||
s = sharding.MeshPspecSharding(mesh, P('x'))
|
||||
x = jnp.arange(len(jax.devices()))
|
||||
y = jax.device_put(x, s)
|
||||
self.assertEqual(y.sharding, s)
|
||||
self.assertArraysAllClose(y, x)
|
||||
|
||||
# this might hit a special fast path
|
||||
z = jax.device_put(y, s)
|
||||
self.assertEqual(z.sharding, s)
|
||||
self.assertArraysAllClose(z, x)
|
||||
self.assertIs(z, y) # no copy
|
||||
|
||||
w = jax.device_put(z)
|
||||
self.assertIs(w, z)
|
||||
|
||||
u = jax.device_put(y, jax.devices()[0])
|
||||
self.assertArraysAllClose(u, y)
|
||||
self.assertEqual(u.device(), jax.devices()[0])
|
||||
|
||||
# TODO(frostig): make this pass with JAX_ENABLE_CUSTOM_PRNG=1
|
||||
# # this can cover opaque dtypes
|
||||
# x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
|
||||
# jax.device_put(x, s) # doesn't crash
|
||||
|
||||
def test_device_get_scalar(self):
|
||||
x = np.arange(12.).reshape((3, 4)).astype("float32")
|
||||
x = api.device_put(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user