[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.

.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
This commit is contained in:
Peter Hawkins 2022-08-25 07:27:54 -07:00 committed by jax authors
parent fd3a72dd1f
commit 5527966b27
17 changed files with 89 additions and 75 deletions

View File

@ -22,6 +22,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Deprecations:
* The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile`
({jax-issue}`#11944`).
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).

View File

@ -2066,7 +2066,7 @@
"\n",
"def handle_result(aval: ShapedArray, buf):\n",
" del aval # Unused for now\n",
" return buf.to_py()\n",
" return np.asarray(buf)\n",
"\n",
"xla_translations = {}"
]
@ -2370,9 +2370,9 @@
" shape = property(lambda self: self.aval.shape)\n",
" ndim = property(lambda self: self.aval.ndim)\n",
"\n",
" def __array__(self): return self.buf.to_py()\n",
" def __repr__(self): return repr(self.buf.to_py())\n",
" def __str__(self): return str(self.buf.to_py())\n",
" def __array__(self): return np.asarray(self.buf)\n",
" def __repr__(self): return repr(np.asarray(self.buf))\n",
" def __str__(self): return str(np.asarray(self.buf))\n",
"\n",
" _neg = staticmethod(neg)\n",
" _add = staticmethod(add)\n",

View File

@ -1626,7 +1626,7 @@ input_handlers = {ty: default_input_handler for ty in
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now
return buf.to_py()
return np.asarray(buf)
xla_translations = {}
```
@ -1842,9 +1842,9 @@ class DeviceArray:
shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim)
def __array__(self): return self.buf.to_py()
def __repr__(self): return repr(self.buf.to_py())
def __str__(self): return str(self.buf.to_py())
def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg)
_add = staticmethod(add)

View File

@ -1620,7 +1620,7 @@ input_handlers = {ty: default_input_handler for ty in
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now
return buf.to_py()
return np.asarray(buf)
xla_translations = {}
@ -1833,9 +1833,9 @@ class DeviceArray:
shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim)
def __array__(self): return self.buf.to_py()
def __repr__(self): return repr(self.buf.to_py())
def __str__(self): return str(self.buf.to_py())
def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg)
_add = staticmethod(add)

View File

@ -18,6 +18,7 @@ from functools import partial, partialmethod
import operator
from typing import (Any, List, Optional, Union)
import weakref
import warnings
import numpy as np
@ -146,7 +147,7 @@ class _DeviceArray(DeviceArray): # type: ignore
def _value(self):
self._check_if_deleted()
if self._npy_value is None:
self._npy_value = self.device_buffer.to_py() # pytype: disable=attribute-error # bind-properties
self._npy_value = np.asarray(self.device_buffer) # pytype: disable=attribute-error # bind-properties
self._npy_value.flags.writeable = False
return self._npy_value
@ -266,6 +267,15 @@ for device_array in [DeviceArray]:
setattr(device_array, "__array__", __array__)
# TODO(phawkins): delete this code path after the deprecation for .to_py()
# expires in Nov 2022.
def to_py(self):
warnings.warn("The .to_py() method on JAX arrays is deprecated. Use "
"np.asarray(...) instead.", category=FutureWarning)
return np.asarray(self._value)
setattr(device_array, "to_py", to_py)
def __dlpack__(self):
from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self)

View File

@ -779,9 +779,9 @@ def check_special(name, bufs):
def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
@ -1155,7 +1155,7 @@ def _copy_device_array_to_device(
else:
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
return device_array.make_device_array(x.aval, device, moved_buf)
@ -1182,7 +1182,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
else:
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(buf.to_py(), device)
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
return array.Array(
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
committed=(device is not None))

View File

@ -88,7 +88,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
def copy_to_device(self, device):
return self
def to_py(self) -> np.ndarray:
def __array__(self, dtype=None, context=None):
return np.asarray(self._buffer)
def to_iree(self):
@ -104,8 +104,11 @@ class IreeBuffer(xla_client.DeviceArrayBase):
return self # no async
# overrides repr on base class which expects _value and aval attributes
def __repr__(self): return f'IreeBuffer({self.to_py()})'
_value = property(to_py)
def __repr__(self): return f'IreeBuffer({np.asarray(self)})'
@property
def _value(self):
return np.asarray(self)
class IreeExecutable:

View File

@ -290,9 +290,6 @@ class Array:
self._check_if_deleted()
return list(self.sharding.device_set)
def to_py(self) -> np.ndarray:
return self._value
@pxla.maybe_cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
@ -367,7 +364,7 @@ class Array:
for s in self.addressable_shards:
if not replica_id_exists or s.replica_id == 0:
npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr]
npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr]
self._npy_value = npy_value # type: ignore
# https://docs.python.org/3/library/typing.html#typing.cast
return cast(np.ndarray, self._npy_value)

View File

@ -72,16 +72,16 @@ class CheckpointTest(jtu.JaxTestCase):
[mesh_axes, P('x'), P(None)],
tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
self.assertArraysEqual(np.asarray(m1.local_shards[0].data),
np.array([[0], [2]]))
self.assertArraysEqual(m1.local_shards[1].data.to_py(),
self.assertArraysEqual(np.asarray(m1.local_shards[1].data),
np.array([[1], [3]]))
self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
self.assertArraysEqual(m2.local_shards[0].data.to_py(),
self.assertArraysEqual(np.asarray(m2.local_shards[0].data),
np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(m2.local_shards[1].data.to_py(),
self.assertArraysEqual(np.asarray(m2.local_shards[1].data),
np.array([[16, 17], [18, 19]]))
self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)
@ -89,7 +89,7 @@ class CheckpointTest(jtu.JaxTestCase):
for i, s in enumerate(m3.local_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(m3.dtype, np.float32)
@jax_config.jax_array(True)
@ -132,17 +132,17 @@ class CheckpointTest(jtu.JaxTestCase):
tspecs)
self.assertIsInstance(m1, array.Array)
self.assertArraysEqual(m1.addressable_shards[0].data.to_py(),
self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
np.array([[0], [2]]))
self.assertArraysEqual(m1.addressable_shards[1].data.to_py(),
self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
np.array([[1], [3]]))
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
self.assertIsInstance(m2, array.Array)
self.assertArraysEqual(m2.addressable_shards[0].data.to_py(),
self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(m2.addressable_shards[1].data.to_py(),
self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
np.array([[16, 17], [18, 19]]))
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)
@ -151,7 +151,7 @@ class CheckpointTest(jtu.JaxTestCase):
for i, s in enumerate(m3.addressable_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self):
@ -192,7 +192,7 @@ class CheckpointTest(jtu.JaxTestCase):
}
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
def test_checkpointing_scalar(self):
global_mesh = jtu.create_global_mesh((2,), ('x'))
@ -216,7 +216,7 @@ class CheckpointTest(jtu.JaxTestCase):
)
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), data.astype(np.float32))
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
def test_spec_has_metadata(self):
spec = {

View File

@ -386,9 +386,12 @@ class GlobalDeviceArray:
for s in self.local_shards if s.replica_id == 0]
npy_value = np.empty(self.shape, self.dtype)
for s in unique_shards:
npy_value[s.index] = s.data.to_py()
npy_value[s.index] = np.asarray(s.data)
return npy_value
def __array__(self, dtype=None, context=None):
return self._value if dtype is None else self._value.astype(dtype)
def local_data(self, index) -> DeviceArray:
return pxla._set_aval(self._device_buffers[index])

View File

@ -103,7 +103,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
def _pjit(inp):
if isinstance(inp, GlobalDeviceArray):
if inp.is_fully_replicated:
return inp.local_data(0).to_py()
return np.asarray(inp.local_data(0))
global_mesh = inp.mesh
in_axis_resources = FROM_GDA
else:
@ -119,7 +119,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
out_axis_resources=None)(inp)
return out.local_data(0).to_py()
return np.asarray(out.local_data(0))
with config_internal.parallel_functions_output_gda(True):
return jax.tree_util.tree_map(_pjit, in_tree)

View File

@ -303,7 +303,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _device_array_constant_handler(val, canonicalize_types):
return _ndarray_constant_handler(val.device_buffer.to_py(),
return _ndarray_constant_handler(np.asarray(val.device_buffer),
canonicalize_types)
for t in device_array.device_array_types:
register_constant_handler(t, _device_array_constant_handler)

View File

@ -699,7 +699,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
precomputed for efficiency. A list the same length as
`device_buffers`. Each index indicates what portion of the full array is
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
device_buffers[i].to_py()`.
np.asarray(device_buffers[i])`.
"""
__slots__ = [
"aval", "device_buffers", "sharding_spec", "indices",
@ -792,7 +792,7 @@ def _sda_value(self):
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
self._npy_value = npy_value
return self._npy_value

View File

@ -194,7 +194,7 @@ class GDATest(jtu.JaxTestCase):
for i, s in enumerate(gda.local_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(gda.dtype, np.float32)
self.assertEqual(
gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
@ -249,10 +249,10 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_batched_callback(
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1]])
self.assertArraysEqual(gda.local_data(0).to_py(),
self.assertArraysEqual(np.asarray(gda.local_data(0)),
expected_first_shard_value)
expected_second_shard_value = np.array([[2, 3]])
self.assertArraysEqual(gda.local_data(1).to_py(),
self.assertArraysEqual(np.asarray(gda.local_data(1)),
expected_second_shard_value)
def test_gda_batched_callback_with_devices(self):
@ -275,10 +275,10 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_batched_callback_with_devices(
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gda.local_data(0).to_py(),
self.assertArraysEqual(np.asarray(gda.local_data(0)),
expected_first_shard_value)
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gda.local_data(1).to_py(),
self.assertArraysEqual(np.asarray(gda.local_data(1)),
expected_second_shard_value)
def test_gda_str_repr(self):

View File

@ -113,7 +113,7 @@ class JaxJitTest(jtu.JaxTestCase):
complex_type = dtypes.canonicalize_dtype(np.complex128)
# int
res = _cpp_device_put(1, device).to_py()
res = np.asarray(_cpp_device_put(1, device))
self.assertEqual(res, 1)
self.assertEqual(res.dtype, int_type)
# We also compare to the Python Jax API, to make sure we have the exact
@ -122,20 +122,20 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(jnp.asarray(1).dtype, res.dtype)
# float
res = _cpp_device_put(1.0, device).to_py()
res = np.asarray(_cpp_device_put(1.0, device))
self.assertEqual(res, 1.0)
self.assertEqual(res.dtype, float_type)
self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)
# bool
for bool_value in [True, False]:
res = _cpp_device_put(bool_value, device).to_py()
res = np.asarray(_cpp_device_put(bool_value, device))
self.assertEqual(res, np.asarray(bool_value))
self.assertEqual(res.dtype, np.bool_)
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
# Complex
res = _cpp_device_put(1 + 1j, device).to_py()
res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j)
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)

View File

@ -135,7 +135,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 1)
self.assertAllClose(
actual.device_buffers[0].to_py(), expected, check_dtypes=False)
np.asarray(actual.device_buffers[0]), expected, check_dtypes=False)
# Repro for a bug on device_buffer aval
_ = repr(actual.device_buffers)
@ -154,7 +154,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False)
@jtu.with_mesh([('x', 2)])
@ -191,7 +191,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py()[:3], expected[:3],
self.assertAllClose(np.asarray(actual.device_buffers[0])[:3], expected[:3],
check_dtypes=False)
def testBasic1DWithMeshContextManager(self):
@ -210,7 +210,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False)
@jtu.with_mesh([('x', 2), ('y', 2)])
@ -232,13 +232,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4)
split0, split1 = np.split(expected, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), split0,
self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), split0,
self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), split1,
self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
check_dtypes=False)
def testBasic2DWithMeshContextManager(self):
@ -261,13 +261,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4)
split0, split1 = np.split(expected, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), split0,
self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), split0,
self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), split1,
self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
check_dtypes=False)
def testDifferentNestedMesh(self):
@ -318,13 +318,13 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertLen(actual.device_buffers, 4)
splits = np.split(expected, 4)
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
check_dtypes=False)
@jtu.with_mesh([('x', 2)])
@ -363,7 +363,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -390,7 +390,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual._arrays[0].to_py(), expected,
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -419,7 +419,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual._arrays[0].to_py(), expected,
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
@ -842,13 +842,13 @@ class PJitTest(jtu.BufferDonationTestCase):
((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
splits = np.split(expected, 4)
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
check_dtypes=False)
for obj in [lowered, compiled]:

View File

@ -3016,7 +3016,7 @@ class ShardArgsTest(jtu.JaxTestCase):
self.assertEqual(len(bufs), 1)
self.assertEqual(len(bufs[0]), nshards)
for buf, idx in zip(bufs[0], indices):
self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False)
self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False)
class ArrayPmapTest(jtu.JaxTestCase):