[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.

.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
This commit is contained in:
Peter Hawkins 2022-08-25 07:27:54 -07:00 committed by jax authors
parent fd3a72dd1f
commit 5527966b27
17 changed files with 89 additions and 75 deletions

View File

@ -22,6 +22,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Deprecations: * Deprecations:
* The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile` * The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile`
({jax-issue}`#11944`). ({jax-issue}`#11944`).
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16 ## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).

View File

@ -2066,7 +2066,7 @@
"\n", "\n",
"def handle_result(aval: ShapedArray, buf):\n", "def handle_result(aval: ShapedArray, buf):\n",
" del aval # Unused for now\n", " del aval # Unused for now\n",
" return buf.to_py()\n", " return np.asarray(buf)\n",
"\n", "\n",
"xla_translations = {}" "xla_translations = {}"
] ]
@ -2370,9 +2370,9 @@
" shape = property(lambda self: self.aval.shape)\n", " shape = property(lambda self: self.aval.shape)\n",
" ndim = property(lambda self: self.aval.ndim)\n", " ndim = property(lambda self: self.aval.ndim)\n",
"\n", "\n",
" def __array__(self): return self.buf.to_py()\n", " def __array__(self): return np.asarray(self.buf)\n",
" def __repr__(self): return repr(self.buf.to_py())\n", " def __repr__(self): return repr(np.asarray(self.buf))\n",
" def __str__(self): return str(self.buf.to_py())\n", " def __str__(self): return str(np.asarray(self.buf))\n",
"\n", "\n",
" _neg = staticmethod(neg)\n", " _neg = staticmethod(neg)\n",
" _add = staticmethod(add)\n", " _add = staticmethod(add)\n",

View File

@ -1626,7 +1626,7 @@ input_handlers = {ty: default_input_handler for ty in
def handle_result(aval: ShapedArray, buf): def handle_result(aval: ShapedArray, buf):
del aval # Unused for now del aval # Unused for now
return buf.to_py() return np.asarray(buf)
xla_translations = {} xla_translations = {}
``` ```
@ -1842,9 +1842,9 @@ class DeviceArray:
shape = property(lambda self: self.aval.shape) shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim) ndim = property(lambda self: self.aval.ndim)
def __array__(self): return self.buf.to_py() def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(self.buf.to_py()) def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(self.buf.to_py()) def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg) _neg = staticmethod(neg)
_add = staticmethod(add) _add = staticmethod(add)

View File

@ -1620,7 +1620,7 @@ input_handlers = {ty: default_input_handler for ty in
def handle_result(aval: ShapedArray, buf): def handle_result(aval: ShapedArray, buf):
del aval # Unused for now del aval # Unused for now
return buf.to_py() return np.asarray(buf)
xla_translations = {} xla_translations = {}
@ -1833,9 +1833,9 @@ class DeviceArray:
shape = property(lambda self: self.aval.shape) shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim) ndim = property(lambda self: self.aval.ndim)
def __array__(self): return self.buf.to_py() def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(self.buf.to_py()) def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(self.buf.to_py()) def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg) _neg = staticmethod(neg)
_add = staticmethod(add) _add = staticmethod(add)

View File

@ -18,6 +18,7 @@ from functools import partial, partialmethod
import operator import operator
from typing import (Any, List, Optional, Union) from typing import (Any, List, Optional, Union)
import weakref import weakref
import warnings
import numpy as np import numpy as np
@ -146,7 +147,7 @@ class _DeviceArray(DeviceArray): # type: ignore
def _value(self): def _value(self):
self._check_if_deleted() self._check_if_deleted()
if self._npy_value is None: if self._npy_value is None:
self._npy_value = self.device_buffer.to_py() # pytype: disable=attribute-error # bind-properties self._npy_value = np.asarray(self.device_buffer) # pytype: disable=attribute-error # bind-properties
self._npy_value.flags.writeable = False self._npy_value.flags.writeable = False
return self._npy_value return self._npy_value
@ -266,6 +267,15 @@ for device_array in [DeviceArray]:
setattr(device_array, "__array__", __array__) setattr(device_array, "__array__", __array__)
# TODO(phawkins): delete this code path after the deprecation for .to_py()
# expires in Nov 2022.
def to_py(self):
warnings.warn("The .to_py() method on JAX arrays is deprecated. Use "
"np.asarray(...) instead.", category=FutureWarning)
return np.asarray(self._value)
setattr(device_array, "to_py", to_py)
def __dlpack__(self): def __dlpack__(self):
from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self) return to_dlpack(self)

View File

@ -779,9 +779,9 @@ def check_special(name, bufs):
def _check_special(name, xla_shape, buf): def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple() assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), np.inexact): if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())): if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}") raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())): if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}") raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect], def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
@ -1155,7 +1155,7 @@ def _copy_device_array_to_device(
else: else:
# buffers from different XLA backends are passed through the host. # buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device) backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device) moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
return device_array.make_device_array(x.aval, device, moved_buf) return device_array.make_device_array(x.aval, device, moved_buf)
@ -1182,7 +1182,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
else: else:
# buffers from different XLA backends are passed through the host. # buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device) backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(buf.to_py(), device) moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
return array.Array( return array.Array(
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf], x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
committed=(device is not None)) committed=(device is not None))

View File

@ -88,7 +88,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
def copy_to_device(self, device): def copy_to_device(self, device):
return self return self
def to_py(self) -> np.ndarray: def __array__(self, dtype=None, context=None):
return np.asarray(self._buffer) return np.asarray(self._buffer)
def to_iree(self): def to_iree(self):
@ -104,8 +104,11 @@ class IreeBuffer(xla_client.DeviceArrayBase):
return self # no async return self # no async
# overrides repr on base class which expects _value and aval attributes # overrides repr on base class which expects _value and aval attributes
def __repr__(self): return f'IreeBuffer({self.to_py()})' def __repr__(self): return f'IreeBuffer({np.asarray(self)})'
_value = property(to_py)
@property
def _value(self):
return np.asarray(self)
class IreeExecutable: class IreeExecutable:

View File

@ -290,9 +290,6 @@ class Array:
self._check_if_deleted() self._check_if_deleted()
return list(self.sharding.device_set) return list(self.sharding.device_set)
def to_py(self) -> np.ndarray:
return self._value
@pxla.maybe_cached_property @pxla.maybe_cached_property
def addressable_shards(self) -> Sequence[Shard]: def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted() self._check_if_deleted()
@ -367,7 +364,7 @@ class Array:
for s in self.addressable_shards: for s in self.addressable_shards:
if not replica_id_exists or s.replica_id == 0: if not replica_id_exists or s.replica_id == 0:
npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr] npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr]
self._npy_value = npy_value # type: ignore self._npy_value = npy_value # type: ignore
# https://docs.python.org/3/library/typing.html#typing.cast # https://docs.python.org/3/library/typing.html#typing.cast
return cast(np.ndarray, self._npy_value) return cast(np.ndarray, self._npy_value)

View File

@ -72,16 +72,16 @@ class CheckpointTest(jtu.JaxTestCase):
[mesh_axes, P('x'), P(None)], [mesh_axes, P('x'), P(None)],
tspecs) tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(), self.assertArraysEqual(np.asarray(m1.local_shards[0].data),
np.array([[0], [2]])) np.array([[0], [2]]))
self.assertArraysEqual(m1.local_shards[1].data.to_py(), self.assertArraysEqual(np.asarray(m1.local_shards[1].data),
np.array([[1], [3]])) np.array([[1], [3]]))
self.assertEqual(m1.local_shards[0].data.shape, (2, 1)) self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32) self.assertEqual(m1.dtype, np.int32)
self.assertArraysEqual(m2.local_shards[0].data.to_py(), self.assertArraysEqual(np.asarray(m2.local_shards[0].data),
np.array([[16, 17], [18, 19]])) np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(m2.local_shards[1].data.to_py(), self.assertArraysEqual(np.asarray(m2.local_shards[1].data),
np.array([[16, 17], [18, 19]])) np.array([[16, 17], [18, 19]]))
self.assertEqual(m2.local_shards[0].data.shape, (2, 2)) self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32) self.assertEqual(m2.dtype, np.int32)
@ -89,7 +89,7 @@ class CheckpointTest(jtu.JaxTestCase):
for i, s in enumerate(m3.local_shards): for i, s in enumerate(m3.local_shards):
self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i) self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([])) self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(m3.dtype, np.float32) self.assertEqual(m3.dtype, np.float32)
@jax_config.jax_array(True) @jax_config.jax_array(True)
@ -132,17 +132,17 @@ class CheckpointTest(jtu.JaxTestCase):
tspecs) tspecs)
self.assertIsInstance(m1, array.Array) self.assertIsInstance(m1, array.Array)
self.assertArraysEqual(m1.addressable_shards[0].data.to_py(), self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
np.array([[0], [2]])) np.array([[0], [2]]))
self.assertArraysEqual(m1.addressable_shards[1].data.to_py(), self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
np.array([[1], [3]])) np.array([[1], [3]]))
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1)) self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32) self.assertEqual(m1.dtype, np.int32)
self.assertIsInstance(m2, array.Array) self.assertIsInstance(m2, array.Array)
self.assertArraysEqual(m2.addressable_shards[0].data.to_py(), self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
np.array([[16, 17], [18, 19]])) np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(m2.addressable_shards[1].data.to_py(), self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
np.array([[16, 17], [18, 19]])) np.array([[16, 17], [18, 19]]))
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2)) self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32) self.assertEqual(m2.dtype, np.int32)
@ -151,7 +151,7 @@ class CheckpointTest(jtu.JaxTestCase):
for i, s in enumerate(m3.addressable_shards): for i, s in enumerate(m3.addressable_shards):
self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i) self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([])) self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(m3.dtype, np.float32) self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self): def test_checkpointing_with_bigger_shape(self):
@ -192,7 +192,7 @@ class CheckpointTest(jtu.JaxTestCase):
} }
for l in m1.local_shards: for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id]) self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
def test_checkpointing_scalar(self): def test_checkpointing_scalar(self):
global_mesh = jtu.create_global_mesh((2,), ('x')) global_mesh = jtu.create_global_mesh((2,), ('x'))
@ -216,7 +216,7 @@ class CheckpointTest(jtu.JaxTestCase):
) )
for l in m1.local_shards: for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), data.astype(np.float32)) self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
def test_spec_has_metadata(self): def test_spec_has_metadata(self):
spec = { spec = {

View File

@ -386,9 +386,12 @@ class GlobalDeviceArray:
for s in self.local_shards if s.replica_id == 0] for s in self.local_shards if s.replica_id == 0]
npy_value = np.empty(self.shape, self.dtype) npy_value = np.empty(self.shape, self.dtype)
for s in unique_shards: for s in unique_shards:
npy_value[s.index] = s.data.to_py() npy_value[s.index] = np.asarray(s.data)
return npy_value return npy_value
def __array__(self, dtype=None, context=None):
return self._value if dtype is None else self._value.astype(dtype)
def local_data(self, index) -> DeviceArray: def local_data(self, index) -> DeviceArray:
return pxla._set_aval(self._device_buffers[index]) return pxla._set_aval(self._device_buffers[index])

View File

@ -103,7 +103,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
def _pjit(inp): def _pjit(inp):
if isinstance(inp, GlobalDeviceArray): if isinstance(inp, GlobalDeviceArray):
if inp.is_fully_replicated: if inp.is_fully_replicated:
return inp.local_data(0).to_py() return np.asarray(inp.local_data(0))
global_mesh = inp.mesh global_mesh = inp.mesh
in_axis_resources = FROM_GDA in_axis_resources = FROM_GDA
else: else:
@ -119,7 +119,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
with maps.Mesh(global_mesh.devices, global_mesh.axis_names): with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
out = pjit(lambda x: x, in_axis_resources=in_axis_resources, out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
out_axis_resources=None)(inp) out_axis_resources=None)(inp)
return out.local_data(0).to_py() return np.asarray(out.local_data(0))
with config_internal.parallel_functions_output_gda(True): with config_internal.parallel_functions_output_gda(True):
return jax.tree_util.tree_map(_pjit, in_tree) return jax.tree_util.tree_map(_pjit, in_tree)

View File

@ -303,7 +303,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _device_array_constant_handler(val, canonicalize_types): def _device_array_constant_handler(val, canonicalize_types):
return _ndarray_constant_handler(val.device_buffer.to_py(), return _ndarray_constant_handler(np.asarray(val.device_buffer),
canonicalize_types) canonicalize_types)
for t in device_array.device_array_types: for t in device_array.device_array_types:
register_constant_handler(t, _device_array_constant_handler) register_constant_handler(t, _device_array_constant_handler)

View File

@ -699,7 +699,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
precomputed for efficiency. A list the same length as precomputed for efficiency. A list the same length as
`device_buffers`. Each index indicates what portion of the full array is `device_buffers`. Each index indicates what portion of the full array is
stored in the corresponding device buffer, i.e. `array[indices[i]] == stored in the corresponding device buffer, i.e. `array[indices[i]] ==
device_buffers[i].to_py()`. np.asarray(device_buffers[i])`.
""" """
__slots__ = [ __slots__ = [
"aval", "device_buffers", "sharding_spec", "indices", "aval", "device_buffers", "sharding_spec", "indices",
@ -792,7 +792,7 @@ def _sda_value(self):
self.copy_to_host_async() self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype) npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices: for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py() npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
self._npy_value = npy_value self._npy_value = npy_value
return self._npy_value return self._npy_value

View File

@ -194,7 +194,7 @@ class GDATest(jtu.JaxTestCase):
for i, s in enumerate(gda.local_shards): for i, s in enumerate(gda.local_shards):
self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i) self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([])) self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(gda.dtype, np.float32) self.assertEqual(gda.dtype, np.float32)
self.assertEqual( self.assertEqual(
gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes), gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
@ -249,10 +249,10 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_batched_callback( gda = GlobalDeviceArray.from_batched_callback(
global_input_shape, global_mesh, mesh_axes, cb) global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1]]) expected_first_shard_value = np.array([[0, 1]])
self.assertArraysEqual(gda.local_data(0).to_py(), self.assertArraysEqual(np.asarray(gda.local_data(0)),
expected_first_shard_value) expected_first_shard_value)
expected_second_shard_value = np.array([[2, 3]]) expected_second_shard_value = np.array([[2, 3]])
self.assertArraysEqual(gda.local_data(1).to_py(), self.assertArraysEqual(np.asarray(gda.local_data(1)),
expected_second_shard_value) expected_second_shard_value)
def test_gda_batched_callback_with_devices(self): def test_gda_batched_callback_with_devices(self):
@ -275,10 +275,10 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_batched_callback_with_devices( gda = GlobalDeviceArray.from_batched_callback_with_devices(
global_input_shape, global_mesh, mesh_axes, cb) global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32) expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gda.local_data(0).to_py(), self.assertArraysEqual(np.asarray(gda.local_data(0)),
expected_first_shard_value) expected_first_shard_value)
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32) expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gda.local_data(1).to_py(), self.assertArraysEqual(np.asarray(gda.local_data(1)),
expected_second_shard_value) expected_second_shard_value)
def test_gda_str_repr(self): def test_gda_str_repr(self):

View File

@ -113,7 +113,7 @@ class JaxJitTest(jtu.JaxTestCase):
complex_type = dtypes.canonicalize_dtype(np.complex128) complex_type = dtypes.canonicalize_dtype(np.complex128)
# int # int
res = _cpp_device_put(1, device).to_py() res = np.asarray(_cpp_device_put(1, device))
self.assertEqual(res, 1) self.assertEqual(res, 1)
self.assertEqual(res.dtype, int_type) self.assertEqual(res.dtype, int_type)
# We also compare to the Python Jax API, to make sure we have the exact # We also compare to the Python Jax API, to make sure we have the exact
@ -122,20 +122,20 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(jnp.asarray(1).dtype, res.dtype) self.assertEqual(jnp.asarray(1).dtype, res.dtype)
# float # float
res = _cpp_device_put(1.0, device).to_py() res = np.asarray(_cpp_device_put(1.0, device))
self.assertEqual(res, 1.0) self.assertEqual(res, 1.0)
self.assertEqual(res.dtype, float_type) self.assertEqual(res.dtype, float_type)
self.assertEqual(jnp.asarray(1.0).dtype, res.dtype) self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)
# bool # bool
for bool_value in [True, False]: for bool_value in [True, False]:
res = _cpp_device_put(bool_value, device).to_py() res = np.asarray(_cpp_device_put(bool_value, device))
self.assertEqual(res, np.asarray(bool_value)) self.assertEqual(res, np.asarray(bool_value))
self.assertEqual(res.dtype, np.bool_) self.assertEqual(res.dtype, np.bool_)
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype) self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
# Complex # Complex
res = _cpp_device_put(1 + 1j, device).to_py() res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j) self.assertEqual(res, 1 + 1j)
self.assertEqual(res.dtype, complex_type) self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype) self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)

View File

@ -135,7 +135,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 1) self.assertLen(actual.device_buffers, 1)
self.assertAllClose( self.assertAllClose(
actual.device_buffers[0].to_py(), expected, check_dtypes=False) np.asarray(actual.device_buffers[0]), expected, check_dtypes=False)
# Repro for a bug on device_buffer aval # Repro for a bug on device_buffer aval
_ = repr(actual.device_buffers) _ = repr(actual.device_buffers)
@ -154,7 +154,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False) self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2) self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected, self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False) check_dtypes=False)
@jtu.with_mesh([('x', 2)]) @jtu.with_mesh([('x', 2)])
@ -191,7 +191,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual[:3], expected[:3], check_dtypes=False) self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2) self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py()[:3], expected[:3], self.assertAllClose(np.asarray(actual.device_buffers[0])[:3], expected[:3],
check_dtypes=False) check_dtypes=False)
def testBasic1DWithMeshContextManager(self): def testBasic1DWithMeshContextManager(self):
@ -210,7 +210,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False) self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2) self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected, self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False) check_dtypes=False)
@jtu.with_mesh([('x', 2), ('y', 2)]) @jtu.with_mesh([('x', 2), ('y', 2)])
@ -232,13 +232,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4) self.assertLen(actual.device_buffers, 4)
split0, split1 = np.split(expected, 2) split0, split1 = np.split(expected, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), split0, self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), split0, self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), split1, self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), split1, self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
check_dtypes=False) check_dtypes=False)
def testBasic2DWithMeshContextManager(self): def testBasic2DWithMeshContextManager(self):
@ -261,13 +261,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4) self.assertLen(actual.device_buffers, 4)
split0, split1 = np.split(expected, 2) split0, split1 = np.split(expected, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), split0, self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), split0, self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), split1, self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), split1, self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
check_dtypes=False) check_dtypes=False)
def testDifferentNestedMesh(self): def testDifferentNestedMesh(self):
@ -318,13 +318,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4) self.assertLen(actual.device_buffers, 4)
splits = np.split(expected, 4) splits = np.split(expected, 4)
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0], self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1], self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2], self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3], self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
check_dtypes=False) check_dtypes=False)
@jtu.with_mesh([('x', 2)]) @jtu.with_mesh([('x', 2)])
@ -363,7 +363,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False) self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2) self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected, self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False) check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -390,7 +390,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False) self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array) self.assertIsInstance(actual, array.Array)
self.assertLen(actual.addressable_shards, 2) self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual._arrays[0].to_py(), expected, self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False) check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -419,7 +419,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False) self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array) self.assertIsInstance(actual, array.Array)
self.assertLen(actual.addressable_shards, 2) self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual._arrays[0].to_py(), expected, self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False) check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -842,13 +842,13 @@ class PJitTest(jtu.BufferDonationTestCase):
((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {})) ((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
splits = np.split(expected, 4) splits = np.split(expected, 4)
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0], self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1], self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2], self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
check_dtypes=False) check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3], self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
check_dtypes=False) check_dtypes=False)
for obj in [lowered, compiled]: for obj in [lowered, compiled]:

View File

@ -3016,7 +3016,7 @@ class ShardArgsTest(jtu.JaxTestCase):
self.assertEqual(len(bufs), 1) self.assertEqual(len(bufs), 1)
self.assertEqual(len(bufs[0]), nshards) self.assertEqual(len(bufs[0]), nshards)
for buf, idx in zip(bufs[0], indices): for buf, idx in zip(bufs[0], indices):
self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False) self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False)
class ArrayPmapTest(jtu.JaxTestCase): class ArrayPmapTest(jtu.JaxTestCase):