Merge pull request #12675 from mattjj:device-put2

PiperOrigin-RevId: 479660808
This commit is contained in:
jax authors 2022-10-07 13:49:57 -07:00
commit 58cd8376ee
5 changed files with 63 additions and 13 deletions

View File

@ -19,6 +19,7 @@ arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays.
"""
from __future__ import annotations
import collections
import functools
@ -2794,13 +2795,18 @@ def make_jaxpr(fun: Callable,
return make_jaxpr_f
def device_put(x, device: Optional[xc.Device] = None):
def device_put(
x, device: Optional[Union[xc.Device, jax.sharding.Sharding]] = None):
"""Transfers ``x`` to ``device``.
Args:
x: An array, scalar, or (nested) standard Python container thereof.
device: The (optional) :py:class:`Device` to which ``x`` should be
transferred. If given, then the result is committed to the device.
device: The (optional) :py:class:`Device` or :py:class:`Sharding` representing the
device(s) to which ``x`` should be transferred. If given, then the result
is committed to the device(s).
Returns:
A copy of ``x`` that resides on ``device``.
If the ``device`` parameter is ``None``, then this operation behaves like the
identity function if the operand is on any device already, otherwise it
@ -2809,10 +2815,8 @@ def device_put(x, device: Optional[xc.Device] = None):
For more details on data placement see the
:ref:`FAQ on data placement <faq-data-placement>`.
This function is always asynchronous, i.e. returns immediately.
Returns:
A copy of ``x`` that resides on ``device``.
This function is always asynchronous, i.e. returns immediately without
blocking the calling Python thread until any transfers are completed.
"""
with config_explicit_device_put_scope():
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)

View File

@ -1292,15 +1292,34 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
committed=(device is not None))
def _device_put_impl(x, device: Optional[Device] = None):
def _device_put_impl(
x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
from jax._src import array, sharding
if isinstance(device, sharding.Sharding):
if not device.is_fully_addressable(): # type: ignore
raise ValueError(
"device_put's second argument must be a Device or a Sharding which "
f"represents addressable devices, but got {device}")
if getattr(x, 'sharding', None) == device:
return x
# TODO(mattjj,yashkatariya,phawkins): runtime fast resharding here?
return array.make_array_from_callback(x.shape, device, lambda idx: x[idx])
# Below this point `device` is either None or a `Device`; the `Sharding`
# case was fully handled above.
if isinstance(x, array.ArrayImpl):
if not x.is_fully_addressable():
raise ValueError(
"device_put's first argument must be a fully addressable array, but "
f"got value with devices {x.devices()}")
if device is None:
return x
elif is_single_device_sharding(x.sharding):
return _copy_array_to_device(x, device)
if device_array.type_is_device_array(x):
return _copy_device_array_to_device(x, device)
if type(x) is array.ArrayImpl and isinstance(x.sharding, sharding.SingleDeviceSharding):
return _copy_array_to_device(x, device)
try:
a = xla.abstractify(x)
except TypeError as err:

View File

@ -1928,7 +1928,7 @@ def _convert_to_array_if_dtype_fails(x):
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order)
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
@_wraps(np.copy, lax_description=_ARRAY_DOC)

View File

@ -3462,7 +3462,7 @@ def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
if xla_extension_version >= 82:
if len(op.tile_assignment_devices) == 1:
return True
return xc.HloSharding.from_proto(op).is_replicated()
return xc.HloSharding.from_proto(op).is_replicated() # type: ignore
else:
return op.type == xc.OpSharding.Type.REPLICATED

View File

@ -56,6 +56,7 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec as P
from jax._src import array, sharding
from jax.experimental import pjit
from jax.experimental import maps
from jax._src import config as jax_config
from jax._src import custom_derivatives
from jax._src import device_array
@ -1486,6 +1487,32 @@ class APITest(jtu.JaxTestCase):
self.assertIsInstance(y2[1][1], np.ndarray)
assert np.all(y2[1][1] == 3 * x)
def test_device_put_sharding(self):
"""Exercises jax.device_put when the placement argument is a Sharding.

Covers: initial placement under a MeshPspecSharding, the re-put
fast path (no copy when the input already has the target sharding),
device_put with no target on a committed array, and falling back to
a plain Device target.
"""
# Shard a 1-D iota across all available devices along mesh axis 'x'.
mesh = maps.Mesh(jax.devices(), ('x',))
s = sharding.MeshPspecSharding(mesh, P('x'))
x = jnp.arange(len(jax.devices()))
y = jax.device_put(x, s)
self.assertEqual(y.sharding, s)
self.assertArraysAllClose(y, x)
# this might hit a special fast path
z = jax.device_put(y, s)
self.assertEqual(z.sharding, s)
self.assertArraysAllClose(z, x)
self.assertIs(z, y) # no copy
# No target given: device_put acts as the identity on a committed array.
w = jax.device_put(z)
self.assertIs(w, z)
# A single Device is still accepted; the result lands on that device.
u = jax.device_put(y, jax.devices()[0])
self.assertArraysAllClose(u, y)
self.assertEqual(u.device(), jax.devices()[0])
# TODO(frostig): make this pass with JAX_ENABLE_CUSTOM_PRNG=1
# # this can cover opaque dtypes
# x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
# jax.device_put(x, s) # doesn't crash
def test_device_get_scalar(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
x = api.device_put(x)